From 7c1f2d1b19225118cbf51353995e56bff1b0c4fa Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 27 Mar 2026 07:44:37 +0000 Subject: [PATCH 01/41] add neo_chat and neo_chat_moe --- lightllm/models/__init__.py | 2 + lightllm/models/llama/model.py | 49 +- lightllm/models/neo_chat/__init__.py | 0 .../models/neo_chat/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 117 +++++ .../models/neo_chat/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 23 + .../layer_weights/transformer_layer_weight.py | 57 +++ lightllm/models/neo_chat/model.py | 53 +++ lightllm/models/neo_chat_moe/__init__.py | 0 lightllm/models/neo_chat_moe/infer_struct.py | 103 +++++ .../neo_chat_moe/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 117 +++++ .../neo_chat_moe/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 23 + .../layer_weights/transformer_layer_weight.py | 48 ++ lightllm/models/neo_chat_moe/model.py | 192 ++++++++ lightllm/models/neo_chat_moe/neo_visual.py | 281 ++++++++++++ .../neo_chat_moe/triton_kernel/__init__.py | 0 .../context_attention_fwd_neo.py | 430 ++++++++++++++++++ .../triton_kernel/get_neo_position.py | 191 ++++++++ .../models/neo_chat_moe/vision_process.py | 141 ++++++ .../visualserver/model_infer/model_rpc.py | 3 + 23 files changed, 1826 insertions(+), 4 deletions(-) create mode 100644 lightllm/models/neo_chat/__init__.py create mode 100644 lightllm/models/neo_chat/layer_infer/__init__.py create mode 100644 lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/neo_chat/layer_weights/__init__.py create mode 100644 lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/neo_chat/model.py create mode 100644 lightllm/models/neo_chat_moe/__init__.py create mode 100644 lightllm/models/neo_chat_moe/infer_struct.py create mode 100644 lightllm/models/neo_chat_moe/layer_infer/__init__.py create mode 100644 lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/neo_chat_moe/layer_weights/__init__.py create mode 100644 lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/neo_chat_moe/model.py create mode 100644 lightllm/models/neo_chat_moe/neo_visual.py create mode 100644 lightllm/models/neo_chat_moe/triton_kernel/__init__.py create mode 100644 lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py create mode 100644 lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py create mode 100644 lightllm/models/neo_chat_moe/vision_process.py diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index f2e29d4a88..a24cbf140a 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -39,4 +39,6 @@ ) from lightllm.models.gpt_oss.model import GptOssTpPartModel from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel +from lightllm.models.neo_chat_moe.model import NeoTpMOEPartModel +from lightllm.models.neo_chat.model import NeoTpPartModel from .registry import get_model, get_model_class diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index c104ebccc9..1c0277e59b 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ 
-73,10 +73,8 @@ def _init_custom(self): """ rope_scaling = self.config.get("rope_scaling", None) if rope_scaling is None: - self._init_to_get_rotary() - return - - if "rope_type" in rope_scaling: + scaling_type = "default" + elif "rope_type" in rope_scaling: scaling_type = rope_scaling["rope_type"] elif "type" in rope_scaling: scaling_type = rope_scaling["type"] @@ -96,6 +94,8 @@ def _init_custom(self): self._init_to_get_rotary() else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + if "rope_theta_hw" in self.config: + self._init_to_get_hw_rotary() return def _init_to_get_rotary(self, default_base=10000): @@ -301,3 +301,44 @@ def _init_to_get_llama3_rotary(self, default_base=10000): self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return + + def _init_to_get_hw_rotary(self, default_base=10000): + partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2) + if self.config.get("rope_scaling", {}) is None: + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) + + base = self.config.get("rope_theta_hw", float(default_base)) + if "max_sequence_length" in self.config: + max_seq_len = self.config["max_sequence_length"] + else: + max_position_embeddings = self.config.get( + "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384 + ) + max_seq_len = max_position_embeddings * rope_scaling_factor + + # NTK + try: + ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula + except: + pass + + full_inv_freq = 1.0 / ( + base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) + ) + inv_freq = full_inv_freq[::2] + t = ( + torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) + / rope_scaling_factor + ) + freqs = torch.outer(t, inv_freq) + + self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda() + return diff --git a/lightllm/models/neo_chat/__init__.py b/lightllm/models/neo_chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat/layer_infer/__init__.py b/lightllm/models/neo_chat/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..ec181a0b8d --- /dev/null +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -0,0 +1,117 @@ +import torch +from functools import partial +from typing import Tuple +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo +from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo +from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd +from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer +from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight +from 
lightllm.distributed import all_reduce +import torch.distributed as dist +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.common.basemodel.attention.base_att import AttControl + + +class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + return + + def _bind_attention(self): + self._context_attention_kernel = self._context_attention_kernel + self._token_attention_kernel = self._token_attention_kernel + return + + def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatTransformerLayerWeight): + input = input.view(-1, self.embed_dim_) + + qkv = layer_weight.qkv_proj.mm(input) + q, cache_kv = qkv.split( + [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 + ) + q = q.view(q.shape[0], self.tp_q_head_num_, self.head_dim_) + q_t, q_hw = q.chunk(2, dim=-1) + + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + k = cache_kv[:, : self.tp_k_head_num_, :] + v = cache_kv[:, self.tp_k_head_num_ :, :] + k_t, k_hw = k.chunk(2, dim=-1) + + q_t_2d = q_t.reshape(q.shape[0], -1) + q_hw_2d = q_hw.reshape(q.shape[0], -1) + k_t_2d = k_t.reshape(k.shape[0], -1) + k_hw_2d = k_hw.reshape(k.shape[0], -1) + layer_weight.qk_norm_weight_(q_t_2d, k_t_2d, eps=self.eps_) + layer_weight.qk_hw_norm_weight_(q_hw_2d, k_hw_2d, eps=self.eps_) + + q_t = q_t_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_hw = q_hw_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_h, q_w = q_hw.chunk(2, dim=-1) + + k_t = k_t_2d.view(k.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_hw = k_hw_2d.view(k.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_h, k_w = k_hw.chunk(2, dim=-1) + + rotary_emb_fwd( + q_t, + k_t, + infer_state.position_cos, + infer_state.position_sin, + ) + rotary_emb_fwd( + q_h, + k_h, + infer_state.position_cos_h, + infer_state.position_sin_h, + ) + rotary_emb_fwd( + q_w, + k_w, + infer_state.position_cos_w, + infer_state.position_sin_w, + ) + + q = torch.cat([q_t, q_h, q_w], dim=-1) + q = q.reshape(q.shape[0], -1) + + k = torch.cat([k_t, k_h, k_w], dim=-1) + cache_kv = torch.cat([k, v], dim=1) + return q, cache_kv + + def _context_attention_kernel( + self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + context_attention_fwd_neo( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] + infer_state.b_req_idx, + infer_state.b_q_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_q_seq_len, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_image_token_tag, + ) + return o_tensor + + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: NeoChatInferStateInfo, + layer_weight: NeoChatTransformerLayerWeight, + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + att_control = AttControl() + # att_control.mla_decode_dict["softmax_scale"] = 1.0 
/ (self.head_dim_ ** 0.5) + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, k=_k, v=_v, att_control=att_control, alloc_func=self.alloc_tensor + ) + o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_)[:, :, : self.head_dim_].contiguous() + return o_tensor diff --git a/lightllm/models/neo_chat/layer_weights/__init__.py b/lightllm/models/neo_chat/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..e6489f39af --- /dev/null +++ b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,23 @@ +import torch +import numpy as np +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight + +# add key: language_model.xxx -> xxx +# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now +def rename_weight_keys(weights): + prefix = "language_model." + keys = list(weights.keys()) + for k in keys: + if prefix in k: + weights[k.replace(prefix, "")] = weights.pop(k) + + +class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + return diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..8351369fd8 --- /dev/null +++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py @@ -0,0 +1,57 @@ +from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + QKRMSNORMWeight, + RMSNormWeight, + QKVROWNMMWeight, +) + + +class NeoChatTransformerLayerWeight(Qwen3TransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _init_weight_names(self): + super()._init_weight_names() + self._q_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_hw.weight" + self._k_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_hw.weight" + + def _init_qkv(self): + in_dim = self.n_embed + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("qkv_proj"), + ) + + def _init_norm(self): + hidden_size = self.network_config_["hidden_size"] + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + ) + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + + self.qk_norm_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + q_weight_name=self._q_norm_name, + k_weight_name=self._k_norm_name, + data_type=self.data_type_, + ) + + self.qk_hw_norm_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + q_weight_name=self._q_norm_hw_name, + 
k_weight_name=self._k_norm_hw_name, + data_type=self.data_type_, + ) diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py new file mode 100644 index 0000000000..14d9f96dc7 --- /dev/null +++ b/lightllm/models/neo_chat/model.py @@ -0,0 +1,53 @@ +import os +import json +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry, llm_model_type_is +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen2_vl.model import QWen2VLTokenizer +from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.server.core.objs import SamplingParams +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem +from lightllm.models.neo_chat_moe.vision_process import smart_resize +from lightllm.models.internvl.model import InternvlTokenizer +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import NeoChatTransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight +from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatPreAndPostLayerWeight +from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo + + +@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3")) +class NeoTpPartModel(Qwen3TpPartModel): + + pre_layer_infer_class = LlamaMultimodalPreLayerInfer + transformer_layer_infer_class = NeoChatTransformerLayerInfer + + pre_and_post_weight_class = NeoChatPreAndPostLayerWeight + transformer_weight_class = NeoChatTransformerLayerWeight + + infer_state_class = NeoChatInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_inferstate_cls(self): + pass + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + all_config = json.load(json_file) + self.config = all_config["llm_config"] + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return diff --git a/lightllm/models/neo_chat_moe/__init__.py b/lightllm/models/neo_chat_moe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py new file mode 100644 index 0000000000..add8abda08 --- /dev/null +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -0,0 +1,103 @@ +from typing import Optional, List +import torch +import numpy as np +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.common.req_manager import ReqManager +from 
lightllm.models.neo_chat_moe.triton_kernel.get_neo_position import get_neo_position_triton +from lightllm.models.llama.model import LlamaTpPartModel + + +class NeoChatInferStateInfo(LlamaInferStateInfo): + def __init__(self): + super().__init__() + self.position_cos = None + self.position_sin = None + self.position_cos_h = None + self.position_sin_h = None + self.position_cos_w = None + self.position_sin_w = None + + def init_some_extra_state(self, model: LlamaTpPartModel): + LlamaInferStateInfo.init_some_extra_state(self, model) + if self.is_prefill: + self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda( + non_blocking=True + ) + self.position_ids = self.get_neo_position(self.multimodal_params) + else: + b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] + for batch_idx, p in enumerate(self.multimodal_params): + position_delta = 0 + for image in p["images"]: + position_delta += image["grid_thwd"][3] + b_position_delta[batch_idx] = position_delta + position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device) + self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone() + self.position_ids[1:].zero_() + + self.position_ids = self.position_ids.contiguous() + self.position_cos = model._cos_cached[self.position_ids[0]] + self.position_sin = model._sin_cached[self.position_ids[0]] + self.position_cos_h = model._hw_cos_cached[self.position_ids[1]] + self.position_sin_h = model._hw_sin_cached[self.position_ids[1]] + self.position_cos_w = model._hw_cos_cached[self.position_ids[2]] + self.position_sin_w = model._hw_sin_cached[self.position_ids[2]] + return + + def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: + if len(multimodal_params) == 0: + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + return position_ids + b_image_start_idx = [] + b_image_nums = [] + b_image_start_num = [] + b_image_len = [] + image_start_num = 0 + b_image_thwd = [] + + # pad multimodal_params to batch size. 
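+        # Each image item contributes: start_idx (offset of its placeholder block in the
+        # packed input ids, set by the tokenizer's encode), token_num (number of image
+        # tokens), and grid_thwd, whose last entry is the position delta applied to the
+        # text that follows the image. Requests without images are padded with empty
+        # entries so the per-batch loop below stays aligned with b_q_seq_len.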
+ batch_size = self.b_q_seq_len.shape[0] + multimodal_params = multimodal_params + [ + {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params)) + ] + + for _, p in enumerate(multimodal_params): + images = p.get("images", []) + for img in images: + b_image_start_idx.append(img["start_idx"]) + # a = img["start_idx"] + # print(f"img start_idx: {a}") + b_image_len.append(img["token_num"]) + b_image_thwd.append(img["grid_thwd"]) + b_image_nums.append(len(images)) + b_image_start_num.append(image_start_num) + image_start_num += len(images) + + # 没有任何图片 + if image_start_num == 0: + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + return position_ids.contiguous() + b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True) + b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4 + b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True) + b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True) + b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True) + + position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0))) + position_ids[0].copy_(self.position_ids) + + get_neo_position_triton( + b_image_start_idx=b_image_start_idx, + b_image_thwd=b_image_thwd, + b_image_nums=b_image_nums, + b_image_start_num=b_image_start_num, + b_image_len=b_image_len, + position_ids=position_ids, + b_ready_cache_len=self.b_ready_cache_len, + b_q_seq_len=self.b_q_seq_len, + b_start_loc=self.b_q_start_loc, + b_image_token_tag=self.b_image_token_tag, + ) + return position_ids diff --git a/lightllm/models/neo_chat_moe/layer_infer/__init__.py b/lightllm/models/neo_chat_moe/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..4c4d8a22ab --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -0,0 +1,117 @@ +import torch +from functools import partial +from typing import Tuple +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo +from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo +from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer +from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight +from lightllm.distributed import all_reduce +import torch.distributed as dist +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.common.basemodel.attention.base_att import AttControl + + +class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + return + + def _bind_attention(self): + self._context_attention_kernel = self._context_attention_kernel + self._token_attention_kernel = self._token_attention_kernel + return + + def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight): + 
input = input.view(-1, self.embed_dim_) + + qkv = layer_weight.qkv_proj.mm(input) + q, cache_kv = qkv.split( + [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 + ) + q = q.view(q.shape[0], self.tp_q_head_num_, self.head_dim_) + q_t, q_hw = q.chunk(2, dim=-1) + + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + k = cache_kv[:, : self.tp_k_head_num_, :] + v = cache_kv[:, self.tp_k_head_num_ :, :] + k_t, k_hw = k.chunk(2, dim=-1) + + q_t_2d = q_t.reshape(q.shape[0], -1) + q_hw_2d = q_hw.reshape(q.shape[0], -1) + k_t_2d = k_t.reshape(k.shape[0], -1) + k_hw_2d = k_hw.reshape(k.shape[0], -1) + layer_weight.qk_norm_weight_(q_t_2d, k_t_2d, eps=self.eps_) + layer_weight.qk_hw_norm_weight_(q_hw_2d, k_hw_2d, eps=self.eps_) + + q_t = q_t_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_hw = q_hw_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2) + q_h, q_w = q_hw.chunk(2, dim=-1) + + k_t = k_t_2d.view(k.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_hw = k_hw_2d.view(k.shape[0], self.tp_k_head_num_, self.head_dim_ // 2) + k_h, k_w = k_hw.chunk(2, dim=-1) + + rotary_emb_fwd( + q_t, + k_t, + infer_state.position_cos, + infer_state.position_sin, + ) + rotary_emb_fwd( + q_h, + k_h, + infer_state.position_cos_h, + infer_state.position_sin_h, + ) + rotary_emb_fwd( + q_w, + k_w, + infer_state.position_cos_w, + infer_state.position_sin_w, + ) + + q = torch.cat([q_t, q_h, q_w], dim=-1) + q = q.reshape(q.shape[0], -1) + + k = torch.cat([k_t, k_h, k_w], dim=-1) + cache_kv = torch.cat([k, v], dim=1) + return q, cache_kv + + def _context_attention_kernel( + self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + context_attention_fwd_neo( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] + infer_state.b_req_idx, + infer_state.b_q_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_q_seq_len, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_image_token_tag, + ) + return o_tensor + + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: NeoChatInferStateInfo, + layer_weight: NeoChatMOETransformerLayerWeight, + ) -> torch.Tensor: + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + att_control = AttControl() + # att_control.mla_decode_dict["softmax_scale"] = 1.0 / (self.head_dim_ ** 0.5) + o_tensor = infer_state.decode_att_state.decode_att( + q=_q, k=_k, v=_v, att_control=att_control, alloc_func=self.alloc_tensor + ) + o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_)[:, :, : self.head_dim_].contiguous() + return o_tensor diff --git a/lightllm/models/neo_chat_moe/layer_weights/__init__.py b/lightllm/models/neo_chat_moe/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..4b0eae91c3 --- /dev/null +++ 
b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,23 @@ +import torch +import numpy as np +from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight + +# add key: language_model.xxx -> xxx +# only change keys at PreAndPostLayerWeight load, TransformLayerWeight is correct now +def rename_weight_keys(weights): + prefix = "language_model." + keys = list(weights.keys()) + for k in keys: + if prefix in k: + weights[k.replace(prefix, "")] = weights.pop(k) + + +class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + return + + def load_hf_weights(self, weights): + rename_weight_keys(weights) + super().load_hf_weights(weights) + return diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..d4f985db45 --- /dev/null +++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py @@ -0,0 +1,48 @@ +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + QKRMSNORMWeight, + ROWMMWeight, + RMSNormWeight, +) + + +class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + self._is_merge_kv = network_config.get("merge_kv", True) + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _init_weight_names(self): + super()._init_weight_names() + self._q_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_hw.weight" + self._k_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_hw.weight" + + def _init_qkv(self): + super()._init_qkv() + + def _init_norm(self): + hidden_size = self.network_config_["hidden_size"] + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + ) + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + + self.qk_norm_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + q_weight_name=self._q_norm_name, + k_weight_name=self._k_norm_name, + data_type=self.data_type_, + ) + + self.qk_hw_norm_weight_ = QKRMSNORMWeight( + dim=self.head_dim // 2, + q_weight_name=self._q_norm_hw_name, + k_weight_name=self._k_norm_hw_name, + data_type=self.data_type_, + ) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py new file mode 100644 index 0000000000..d9f40d7feb --- /dev/null +++ b/lightllm/models/neo_chat_moe/model.py @@ -0,0 +1,192 @@ +import os +import json +from lightllm.common.build_utils import repair_config +from lightllm.models.registry import ModelRegistry, llm_model_type_is +from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer +from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight +from lightllm.models.qwen2_vl.model import QWen2VLTokenizer +from lightllm.models.qwen3.model import Qwen3TpPartModel +from lightllm.server.core.objs import 
SamplingParams +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem +from lightllm.models.neo_chat_moe.vision_process import smart_resize +from lightllm.models.internvl.model import InternvlTokenizer +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from lightllm.models.neo_chat_moe.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight +from lightllm.models.neo_chat_moe.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight +from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer +from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo + +IMG_START_TOKEN = "" +IMG_END_TOKEN = "" +IMG_TOKEN = "" +AUDIO_START_TOKEN = "" + + +class NeoChatTokenizer(BaseMultiModalTokenizer): + def __init__(self, tokenizer, model_cfg, **kwargs): + super().__init__(tokenizer) + self.tokenizer = tokenizer + self.min_pixel = model_cfg.get("vision_config").get("min_pixels") + self.max_pixel = model_cfg.get("vision_config").get("max_pixels") + self.patch_size = model_cfg.get("vision_config").get("patch_size") + self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio") + + self.image_token_id = model_cfg.get("image_token_id") + self.image_start_tag = IMG_START_TOKEN + self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag) + self.image_end_tag = IMG_END_TOKEN + self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag) + self.conversation_module = self.load_conversion_module(tokenizer.name_or_path) + self.template = model_cfg.get("template", "neo1_0") + + def load_conversion_module(self, model_dir: str): + import importlib + + conversion_path = os.path.join(model_dir, "conversation.py") + if not os.path.exists(conversion_path): + return None + + spec = importlib.util.spec_from_file_location("conversation", str(conversion_path)) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) # must run the module + return module + + def init_imageitem_extral_params( + self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams + ): + img.extra_params["min_pixels"] = ( + sampling_params.min_pixels if sampling_params.min_pixels > 0 else self.min_pixel + ) + img.extra_params["max_pixels"] = ( + sampling_params.max_pixels if sampling_params.max_pixels > 0 else self.max_pixel + ) + assert ( + img.extra_params["min_pixels"] <= img.extra_params["max_pixels"] + ), "min_pixels should be less than or equal to max_pixels" + return + + def init_audioitem_extral_params( + self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams + ): + raise NotImplementedError + + def get_audio_token_length(self, audio: AudioItem): + raise NotImplementedError + + def get_image_token_length(self, img: ImageItem): + width, height = img.image_w, img.image_h + resized_height, resized_width = smart_resize( + height=height, + width=width, + factor=int(self.patch_size // self.downsample_ratio), + min_pixels=img.extra_params["min_pixels"], + max_pixels=img.extra_params["max_pixels"], + ) + grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size + token_num = int((grid_h * grid_w) * 
(self.downsample_ratio ** 2)) + # 这里的grid_h和grid_w需要* self.downsample_ratio么?再仔细看下代码 + img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num) + return token_num + + # only change the impl of the encode func: + def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs): + # TEXTTEXTTEXT --> TEXTTEXTTEXT + image_tokens = IMG_START_TOKEN + IMG_END_TOKEN + if multimodal_params is None: + add_special_tokens = kwargs.get("add_special_tokens", True) + return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens) + image_count = len(multimodal_params.images) + if not kwargs.get("already_tokenized", False): + prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count) + origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"]) + else: + origin_ids = prompt + # --> id,id+1...id+num + input_ids = [] + image_id = 0 + start_idx = 0 + while True: + try: + start_idx = origin_ids.index(self.image_start_id) + if start_idx + 1 >= len(origin_ids): + break + if origin_ids[start_idx + 1] == self.image_end_id: + input_ids.extend(origin_ids[: start_idx + 1]) + token_id = multimodal_params.images[image_id].token_id + token_num = multimodal_params.images[image_id].token_num + multimodal_params.images[image_id].start_idx = len(input_ids) + input_ids.extend(range(token_id, token_id + token_num)) + input_ids.append(self.image_end_id) + origin_ids = origin_ids[start_idx + 2 :] + image_id += 1 + else: + raise ValueError("image token error") + except ValueError: + break + input_ids.extend(origin_ids) + return input_ids + + def _build_t2i_query(self, msg, thinking_content=""): + template = self.conversation_module.get_conv_template(self.template) + template.append_message(template.roles[0], msg) + template.append_message(template.roles[1], None) + return template.get_prompt() + thinking_content + IMG_START_TOKEN + + def fix_prompt(self, prompt: str, img_len: int): + prompt_img_len = prompt.count(IMG_TOKEN) + assert prompt_img_len <= img_len, f"not enough images provided, need {prompt_img_len}, given {img_len}" + if prompt_img_len < img_len: + return f"{IMG_TOKEN}\n" * (img_len - prompt_img_len) + prompt + return prompt + + def get_query_for_it2i(self, prompt: str): + image_len = prompt.count(IMG_TOKEN) + query_condition = self._build_t2i_query(prompt, thinking_content="\n\n\n\n") + query_text_uncondition = self._build_t2i_query(IMG_TOKEN * image_len) + question_img_uncondition = self._build_t2i_query("") + return query_condition, query_text_uncondition, question_img_uncondition + + def get_query_for_t2i(self, prompt): + query_condition = self._build_t2i_query( + f"Please generate an image based on the following description: {prompt}", + thinking_content="\n\n\n\n", + ) + query_uncondition = self._build_t2i_query(f"") + return query_condition, query_uncondition + + +@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3_moe")) +class NeoTpMOEPartModel(Qwen3MOEModel): + + pre_layer_infer_class = LlamaMultimodalPreLayerInfer + transformer_layer_infer_class = NeoChatMOETransformerLayerInfer + + pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight + transformer_weight_class = NeoChatMOETransformerLayerWeight + + infer_state_class = NeoChatInferStateInfo + + def __init__(self, kvargs): + super().__init__(kvargs) + return + + def _init_inferstate_cls(self): + pass + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + 
all_config = json.load(json_file) + self.config = all_config["llm_config"] + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + return diff --git a/lightllm/models/neo_chat_moe/neo_visual.py b/lightllm/models/neo_chat_moe/neo_visual.py new file mode 100644 index 0000000000..a516a99f64 --- /dev/null +++ b/lightllm/models/neo_chat_moe/neo_visual.py @@ -0,0 +1,281 @@ +import os +import torch +import torch.nn.functional as F +from PIL import Image +from typing import List +from io import BytesIO +import torch.nn as nn +from transformers.activations import ACT2FN +from safetensors import safe_open +from lightllm.server.multimodal_params import ImageItem +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from lightllm.models.neo_chat_moe.vision_process import load_image_native +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data + + +def apply_rotary_emb_1d( + x: torch.Tensor, + cos_cached: torch.Tensor, + sin_cached: torch.Tensor, + positions: torch.Tensor, +): + """对输入张量的一部分应用1D RoPE。""" + # x: (..., seq_len, dim_part) + # positions: (..., seq_len) + # cos_cached: (max_pos, dim_part / 2) + cos_cached = cos_cached.to(device=positions.device) + sin_cached = sin_cached.to(device=positions.device) + + cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2) + sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2) + + x1 = x[..., 0::2] + x2 = x[..., 1::2] + + rotated_x1 = x1 * cos - x2 * sin + rotated_x2 = x1 * sin + x2 * cos + + x_rotated = torch.empty_like(x) + x_rotated[..., 0::2] = rotated_x1 + x_rotated[..., 1::2] = rotated_x2 + return x_rotated + + +def apply_2d_rotary_pos_emb( + x: torch.Tensor, + cos_cached_x: torch.Tensor, + sin_cached_x: torch.Tensor, + cos_cached_y: torch.Tensor, + sin_cached_y: torch.Tensor, + abs_positions_x: torch.Tensor, + abs_positions_y: torch.Tensor, +): + """应用2D RoPE到输入张量x。""" + dim = x.shape[-1] + dim_half = dim // 2 + + # 假设我们将embedding的前半部分用于一个方向的RoPE,后半部分用于另一个方向 + # 例如,前一半给X坐标,后一半给Y坐标 (或者反过来,但要保持一致) + x_part_1 = x[..., :dim_half] + x_part_2 = x[..., dim_half:] + + # 将与 abs_positions_x 相关的旋转应用于 x_part_1 + rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x) + # 将与 abs_positions_y 相关的旋转应用于 x_part_2 + rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y) + + # 将它们重新拼接起来。确保顺序与你分割时一致。 + return torch.cat((rotated_part_1, rotated_part_2), dim=-1) + + +def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): + """ + Compute patch coordinates (x, y) + + Args: + grid_hw: (B, 2) tensor representing (H, W) per image + """ + device = grid_hw.device + B = grid_hw.shape[0] + + # Get the number of patches per image + H = grid_hw[:, 0] + W = grid_hw[:, 1] + N = H * W + N_total = N.sum() + + # Create the batch index for each patch (B x patch count) + patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,) + + # Generate intra-image patch index (row-major order) + patch_id_within_image = torch.arange(N_total, device=device) + patch_id_within_image = ( + patch_id_within_image + - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), 
dim=0)[patch_to_sample] + ) + + # Get H/W for each patch according to its image + W_per_patch = W[patch_to_sample] + abs_x = patch_id_within_image % W_per_patch + abs_y = patch_id_within_image // W_per_patch + + return abs_x, abs_y + + +class NeoVisionTransformerPretrainedModel(nn.Module): + def __init__( + self, + kvargs, + hidden_size: int = 1024, + llm_hidden_size: int = 2048, + downsample_ratio: float = 0.5, + patch_size: int = 16, + num_channels: int = 3, + max_position_embeddings_vision: int = 10000, + rope_theta_vision: float = 10000.0, + min_pixels: int = 65536, + max_pixels: int = 2408448, + **kwargs, + ): + super().__init__() + self.weight_dir = kvargs["weight_dir"] + self.data_type = kvargs.get("data_type", "bfloat16") + self.embed_dim = hidden_size + self.llm_hidden_size = llm_hidden_size + self.patch_size = patch_size + self.num_channels = num_channels + self.downsample_ratio = downsample_ratio + self.downsample_factor = int(1 / downsample_ratio) + self.max_position_embeddings_vision = max_position_embeddings_vision + self.rope_theta_vision = rope_theta_vision + self.rope_dim_part = self.embed_dim // 2 + self.min_pixels = min_pixels + self.max_pixels = max_pixels + + self.patch_embedding = nn.Conv2d( + in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size + ) + + self.dense_embedding = nn.Conv2d( + in_channels=self.embed_dim, + out_channels=self.llm_hidden_size, + kernel_size=self.downsample_factor, + stride=self.downsample_factor, + ) + self.gelu = nn.GELU() + + self.repe_dim_part = self.embed_dim // 2 + self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos() + self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos() + self._init_datatype() + + def _init_datatype(self): + if isinstance(self.data_type, torch.dtype): + return + if self.data_type in ["fp16", "float16"]: + self.data_type = torch.float16 + elif self.data_type in ["bf16", "bfloat16"]: + self.data_type = torch.bfloat16 + elif self.data_type in ["fp32", "float32"]: + self.data_type = torch.float32 + else: + raise ValueError(f"Unsupport datatype {self.data_type}!") + return + + def load_model(self, weight_dir): + bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] + if bin_weight_files: + weight_dict = {} + for file_ in bin_weight_files: + f = torch.load(os.path.join(weight_dir, file_), "cpu") + for k, v in f.items(): + if "vision_model" in k and "fm_modules" not in k: + weight_dict[k[len("vision_model.embeddings.") :]] = v + else: + hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")] + weight_dict = {} + for file_ in hf_weight_files: + f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu") + for k in f.keys(): + if "vision_model" in k and "fm_modules" not in k: + weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k) + self.load_state_dict(weight_dict) + + def precompute_rope_freqs_sincos(self): + inv_freq = 1.0 / ( + self.rope_theta_vision ** (torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part) + ) + t = torch.arange(self.max_position_embeddings_vision).type_as(inv_freq) + freqs = torch.outer(t, inv_freq) + return torch.cos(freqs), torch.sin(freqs) + + def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw): + """ + Apply 2D Rotary Position Embedding to the patch embeddings. 
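+
+        Args:
+            patch_embeds: (N_total, embed_dim) flattened patch embeddings for the whole batch.
+            grid_hw: (B, 2) per-image patch grid (H, W), used to recover each patch's (x, y) position.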
+ """ + abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device) + embeddings = apply_2d_rotary_pos_emb( + patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32 + self.cos_x, + self.sin_x, + self.cos_y, + self.sin_y, + abs_pos_x, + abs_pos_y, + ).to(self.patch_embedding.weight.dtype) + return embeddings + + def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor: + pixel_values = pixel_values.view( + -1, + 3, + self.patch_size, + self.patch_size, + ) + patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim) + patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) + assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[ + 0 + ], "Grid size and patch embeds size mismatch." + + patches_list = [] + cur_position = 0 + for i in range(grid_hw.shape[0]): + h, w = grid_hw[i] + patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0) + patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2)) + patches_per_img = patches_per_img.permute(0, 2, 3, 1) + patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1])) + cur_position += h * w + + embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C) + assert cur_position == patch_embeds.shape[0] + assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2) + + return embeddings + + def encode(self, images: List[ImageItem]): + img_tensors = [] + valid_ids = [] + valid_id = 0 + img_grids = [] + uuids = [] + + for i, img in enumerate(images): + if isinstance(img, ImageItem): + uuids.append(img.uuid) + image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = Image.open(BytesIO(image_data)) + # a = img.extra_params["min_pixels"] + # b = img.extra_params["max_pixels"] + # print(f"self.min_pixels is {a} ,max_pixelx is {b}") + pixel_values, image_grid_hw = load_image_native( + image_data, + patch_size=self.patch_size, + downsample_ratio=self.downsample_ratio, + min_pixels=img.extra_params["min_pixels"], + max_pixels=img.extra_params["max_pixels"], + ) + img_tensors.append(pixel_values) + img_grids.append(image_grid_hw) + else: + raise Exception("Unsupport input types: {} for {}".format(type(img), img)) + + # must devide merge_length + cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2)) + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + imgs = torch.cat(img_tensors, dim=0) + grid_hw = torch.cat(img_grids, dim=0) + + pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True) + image_grid_hw = grid_hw.to("cuda", non_blocking=True) + + all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw) + + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/neo_chat_moe/triton_kernel/__init__.py b/lightllm/models/neo_chat_moe/triton_kernel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py new file mode 100644 index 0000000000..74ff82cae4 --- /dev/null +++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py @@ -0,0 +1,430 @@ +import math +import torch +import triton +import triton.language as tl + +from lightllm.utils.device_utils import is_tesla + + +@triton.jit +def _fwd_kernel( + Q, 
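+    # K/V below are the full token-indexed cache buffers; the rows belonging to this
+    # request are gathered via Req_to_tokens[B_req_idx] inside the attention loop.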
+ K, + V, + sm_scale, + Out, + position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0] + B_Start_Loc, + B_Seqlen, + Req_to_tokens, + B_req_idx, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + kv_group_num, + b_prompt_cache_len, + b_image_token_tag, + H: tl.constexpr, + QK_HEAD_DIM: tl.constexpr, + V_HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + cur_bh = tl.program_id(1) + cur_batch = cur_bh // H + cur_head = cur_bh % H + + cur_kv_head = cur_head // kv_group_num + + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + total_len = tl.load(B_Seqlen + cur_batch) + cur_batch_seq_len = total_len - prompt_cache_len # NEW len + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + block_start_loc = BLOCK_M * start_m + if block_start_loc >= cur_batch_seq_len: + return + + offs_n = tl.arange(0, BLOCK_N) + offs_d_qk = tl.arange(0, QK_HEAD_DIM) + offs_d_v = tl.arange(0, V_HEAD_DIM) + offs_m = block_start_loc + tl.arange(0, BLOCK_M) + + # Q pointers + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d_qk[None, :] * stride_qd + ) + + q_valid = offs_m < cur_batch_seq_len + q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) + + # online softmax state + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32) + block_end_loc = total_len + + # absolute q positions in the request + q_pos = prompt_cache_len + offs_m # [M] + q_image_token_tag = tl.load(b_image_token_tag + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=False) + + for start_n in range(0, block_end_loc, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + k_pos = start_n + offs_n # [N] + k_valid = k_pos < block_end_loc + + # map logical pos -> mem_index (for K/V) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, + mask=k_valid, + other=0, + ).to(tl.int64) + + # load K + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d_qk[:, None] * stride_kd + k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + + # mask: causal OR same gid (only possible inside NEW part) + mask = (q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None] + qk = tl.where(mask, qk * sm_scale, -1.0e8) + + # online softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + # load V + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d_v[None, :] * stride_vd + v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) + + p = p.to(v.dtype) + acc = tl.dot(p, v, acc) + + m_i = m_ij + + acc = acc / l_i[:, None] + + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d_v[None, :] * stride_od + ) + tl.store(Out + off_o, acc, mask=q_valid[:, None]) + + +@torch.no_grad() +def context_attention_fwd_neo( + q, + k, + v, + o, + position_ids, # 1D packed like q (only NEW tokens) + b_req_idx, + 
b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, + b_image_token_tag, +): + # minimal safety: position_ids must cover packed q rows + assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) + + BLOCK_M = 128 if not is_tesla() else 64 + + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128, 256} + sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634 + + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) + + BLOCK_N = BLOCK_M + num_warps = 4 if Lk <= 64 else 8 + num_stages = 1 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + o, + position_ids, + b_start_loc, + b_seq_len, + req_to_token_indexs, + b_req_idx, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), + kv_group_num=kv_group_num, + b_prompt_cache_len=b_prompt_cache_len, + b_image_token_tag=b_image_token_tag, + H=head, + QK_HEAD_DIM=Lk, + V_HEAD_DIM=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + num_stages=num_stages, + ) + + +def reference_attention( + q, + k, + v, + position_ids_q, # 1D packed like q (only NEW tokens) + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, +): + device = q.device + dtype = q.dtype + sum_q, Hq, D = q.shape + Hk = k.shape[1] + kv_group_num = Hq // Hk + + batch = b_seq_len.shape[0] + out = torch.empty_like(q) + scale = 1.0 / math.sqrt(D) + + for b in range(batch): + req = int(b_req_idx[b].item()) + total_len = int(b_seq_len[b].item()) + prompt_len = int(b_prompt_cache_len[b].item()) + new_len = total_len - prompt_len + + q_start = int(b_start_loc[b].item()) + q_blk = q[q_start : q_start + new_len] # [M, Hq, D] + gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M] + + # gather K/V for full request by logical pos -> mem_index + token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] + k_blk = k[token_locs] # [L, Hk, D] + v_blk = v[token_locs] # [L, Hk, D] + + # expand kv heads to q heads (GQA) + k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D] + + # positions + q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M] + k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L] + + # build allow mask: + # causal always + allow = k_pos[None, :] <= q_pos[:, None] + + # full-attn only inside NEW part by gid + # compare only when k_pos in NEW + k_in_new = k_pos >= prompt_len + k_rel = (k_pos - prompt_len).clamp_min(0) # [L] + # map k_rel to gid_new, but only valid where k_in_new + k_gid = torch.empty((total_len,), device=device, dtype=torch.int64) + k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new + k_gid[k_in_new] = gid_new[k_rel[k_in_new]] + + allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) + + # scores: [Hq, M, L] + q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] + k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L] + scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L] + + neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) + scores = torch.where(allow[None, :, :], scores, neg) + + p = 
torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L] + v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D] + out_hq = torch.matmul(p, v_t) # [Hq, M, D] + out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D] + + out[q_start : q_start + new_len] = out_blk + + return out + + +def make_test_case( + device="cuda", + dtype=torch.float16, + batch=3, + Hq=8, + Hk=4, + D=64, + seed=0, + base_index=50000, +): + torch.manual_seed(seed) + + # prompt (cached) len and new len + prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device) + new_lens = torch.randint(low=1, high=8, size=(batch,), device=device) + total_lens = (prompt_lens + new_lens).to(torch.int32) + + max_total_len = int(total_lens.max().item()) + max_new_len = int(new_lens.max().item()) + + # packed q start + b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) + cur = 0 + for b in range(batch): + b_start_loc[b] = cur + cur += int(new_lens[b].item()) + sum_q = cur + + b_seq_len = total_lens + b_prompt_cache_len = prompt_lens.to(torch.int32) + + # one req per batch + num_req = batch + b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) + + # global KV space large, indices not small + sum_kv = int(total_lens.sum().item()) + kv_size = base_index + sum_kv + 1024 + pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index + + # Req_to_tokens [num_req, max_total_len] + req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32) + p = 0 + for r in range(num_req): + L = int(total_lens[r].item()) + req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) + p += L + + # position_ids_q: only NEW tokens, packed like q + position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32) + for b in range(batch): + M = int(new_lens[b].item()) + start = int(b_start_loc[b].item()) + + gid = torch.arange(M, device=device, dtype=torch.int32) + + # make one repeated block inside NEW part to simulate image tokens + if M >= 4 and torch.rand((), device=device).item() > 0.3: + s = int(torch.randint(0, M - 2, (1,), device=device).item()) + e = min(M, s + 3) + gid[s:e] = gid[s] + + position_ids_q[start : start + M] = gid + + q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) + k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype) + + return ( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) + + +def check_once(device="cuda", dtype=torch.float16, seed=0): + ( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) = make_test_case(device=device, dtype=dtype, seed=seed) + + context_attention_fwd_neo( + q, + k, + v, + o, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + ) + + ref = reference_attention( + q, + k, + v, + position_ids_q, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, + ) + + diff = (o - ref).abs() + max_abs = diff.max().item() + denom = ref.abs().max().item() + 1e-6 + max_rel = max_abs / denom + + print(f"seed={seed}, dtype={dtype}") + print(f"max_abs_error = {max_abs:.6e}") + print(f"max_rel_error = {max_rel:.6e}") + print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, 
rtol=5e-2)) + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("No CUDA, skip.") + else: + torch.cuda.synchronize() + check_once(dtype=torch.bfloat16, seed=0) + check_once(dtype=torch.bfloat16, seed=1) + check_once(dtype=torch.bfloat16, seed=2) diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py new file mode 100644 index 0000000000..1a3d4af73b --- /dev/null +++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py @@ -0,0 +1,191 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _get_neo_position_triton( + b_image_start_idx: torch.Tensor, + b_image_thwd: torch.Tensor, + b_image_thwd_stride0: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_image_len: torch.Tensor, + position_ids: torch.Tensor, + position_ids_stride0: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_start_loc: torch.Tensor, + b_image_token_tag: torch.Tensor, + BLOCK_SIZE: tl.constexpr, +) -> torch.Tensor: + cur_batch = tl.program_id(0) + cache_len = tl.load(b_ready_cache_len + cur_batch) + q_seq_len = tl.load(b_q_seq_len + cur_batch) + image_num = tl.load(b_image_nums + cur_batch) + image_start_num = tl.load(b_image_start_num + cur_batch) + start_loc = tl.load(b_start_loc + cur_batch) + for i in range(image_num): + local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) + image_start_idx = start_loc + local_image_start_idx - cache_len + image_len = tl.load(b_image_len + image_start_num + i) + # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) + image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) + for j in range(0, image_len, BLOCK_SIZE): + off = j + tl.arange(0, BLOCK_SIZE) + # 目前没考虑视频,所以t 恒为 0 + t_pos = local_image_start_idx + off * 0 + h_pos = off // image_w + w_pos = off % image_w + tl.store( + b_image_token_tag + off + image_start_idx, + True, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + tl.store( + position_ids + off + image_start_idx, + t_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + tl.store( + position_ids + position_ids_stride0 + off + image_start_idx, + h_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + tl.store( + position_ids + position_ids_stride0 * 2 + off + image_start_idx, + w_pos, + mask=(off < image_len) + & (off + local_image_start_idx - cache_len < q_seq_len) + & (local_image_start_idx - cache_len + off >= 0), + ) + + for i in range(image_num): + local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) + image_len = tl.load(b_image_len + image_start_num + i) + image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3) + image_end = local_image_start_idx + image_len - cache_len + text_start = tl.maximum(0, image_end) + for j in range(text_start, q_seq_len, BLOCK_SIZE): + off = j + tl.arange(0, BLOCK_SIZE) + t_pos = tl.load(position_ids + off + start_loc, mask=(off < q_seq_len), other=0.0) + image_delta + h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0) + w_pos = tl.load( + position_ids + position_ids_stride0 * 2 + 
off + start_loc, mask=(off < q_seq_len), other=0.0 + ) + tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len)) + tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len)) + tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len)) + return + + +def get_neo_position_triton( + b_image_start_idx: torch.Tensor, + b_image_thwd: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_image_len: torch.Tensor, + position_ids: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_start_loc: torch.Tensor, + b_image_token_tag: torch.Tensor, +) -> torch.Tensor: + + batch_size = b_q_seq_len.shape[0] + assert batch_size == b_image_nums.shape[0] + grid = (batch_size,) + BLOCK_SIZE = 64 + _get_neo_position_triton[grid]( + b_image_start_idx=b_image_start_idx, + b_image_thwd=b_image_thwd, + b_image_thwd_stride0=b_image_thwd.stride(0), + b_image_nums=b_image_nums, + b_image_start_num=b_image_start_num, + b_image_len=b_image_len, + position_ids=position_ids, + position_ids_stride0=position_ids.stride(0), + b_ready_cache_len=b_ready_cache_len, + b_q_seq_len=b_q_seq_len, + b_start_loc=b_start_loc, + b_image_token_tag=b_image_token_tag, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +def test(): + b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda") + b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda") + b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda") + b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda") + position_ids = ( + torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + .unsqueeze(0) + .expand(3, -1) + .contiguous() + ) + b_image_token_tag = torch.zeros([position_ids.size(1)], dtype=torch.bool, device="cuda") + position_ids[1:].zero_() + b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") + b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") + b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda") + get_neo_position_triton( + b_image_start_idx, + b_image_thwd, + b_image_nums, + b_image_start_num, + b_image_len, + position_ids, + b_ready_cache_len, + b_q_seq_len, + b_start_loc, + b_image_token_tag, + ) + + print(b_image_token_tag) + print(position_ids) + # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) + + # position_ids = ( + # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda") + # .unsqueeze(0) + # .expand(3, -1) + # .contiguous() + # ) + # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda") + # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda") + # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda") + + # get_neo_position_triton( + # b_image_start_idx, + # b_image_thwd, + # b_image_nums, + # b_image_start_num, + # b_image_len, + # position_ids, + # b_ready_cache_len, + # b_q_seq_len, + # b_start_loc, + # ) + + # print(f"old_value:\n{old_value}") + # print(f"position_ids:\n{position_ids}") + # assert torch.equal(old_value, position_ids) + + """ + tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8], + [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8], + [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 
3, 4, 5, 6, 7, 8]], + device='cuda:0', dtype=torch.int32) + """ + + +if __name__ == "__main__": + test() diff --git a/lightllm/models/neo_chat_moe/vision_process.py b/lightllm/models/neo_chat_moe/vision_process.py new file mode 100644 index 0000000000..fbd57a5e9c --- /dev/null +++ b/lightllm/models/neo_chat_moe/vision_process.py @@ -0,0 +1,141 @@ +import re +import math +import torch +import string +import numpy as np +import pandas as pd +from PIL import Image +import torch.distributed as dist +import torchvision.transforms as T + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60 +def smart_resize( + height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304 +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, floor_by_factor(height / beta, factor)) + w_bar = max(factor, floor_by_factor(width / beta, factor)) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs): + width, height = image.size + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + + return image + + +def preprocess_pixel_values(pixel_values, patch_size=16): + c, h, w = pixel_values.shape + grid_h = h // patch_size + grid_w = w // patch_size + + flatten_pixel_values = ( + pixel_values.view(c, grid_h, patch_size, grid_w, patch_size) + .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size] + .reshape(grid_h * grid_w, c * patch_size ** 2) + ) + + grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device) + + return flatten_pixel_values, grid_hw + + +def get_contrasting_background(image): + """ + Calculate the color (white or black) that is different from the average foreground color + to use as the background color + """ + image_np = np.array(image) + if (image_np[:, :, 3] == 0).any(): + 
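+        # a background is synthesized only when fully transparent pixels exist; the mean of the
+        # opaque RGB values then decides whether black or white gives the stronger contrast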
non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0] + if non_transparent_pixels.size == 0: + return None + pixel_mean = non_transparent_pixels.mean() + contrasting_color = (0, 0, 0) if pixel_mean > 382.5 else (255, 255, 255) + return contrasting_color + else: + return None + + +def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False): + """ + Load and preprocess an image file, converting it to RGB mode, + resizing, normalizing, and optionally adding a thumbnail version. + """ + if image.mode == "RGBA": + bg_color = get_contrasting_background(image) + if bg_color: + background = Image.new("RGB", image.size, bg_color) + background.paste(image, mask=image.split()[3]) + image = background.convert("RGB") + else: + image = image.convert("RGB") + else: + image = image.convert("RGB") + + if upscale: + image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR) + + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.ToTensor(), + T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), + ] + ) + + new_image = dynamic_preprocess_native_resolution( + image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels + ) + pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size) + + # print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})") + + return pixel_values, grid_hw diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 6355ac2dbf..a1f71cf02a 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -20,6 +20,7 @@ from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel +from lightllm.models.neo_chat_moe.neo_visual import NeoVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry @@ -90,6 +91,8 @@ def exposed_init_model(self, kvargs): .eval() .bfloat16() ) + elif self.model_type == "neo_chat": + self.model = NeoVisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() else: raise Exception(f"can not support {self.model_type} now") From 0272ec39aeb62557917a01ebcc299c30299102bf Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 27 Mar 2026 07:59:37 +0000 Subject: [PATCH 02/41] neo rope --- lightllm/models/llama/model.py | 98 ++++++++++--------- .../layer_infer/transformer_layer_infer.py | 1 + 2 files changed, 52 insertions(+), 47 deletions(-) diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 1c0277e59b..20d6cad743 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -73,13 +73,20 @@ def _init_custom(self): """ rope_scaling = self.config.get("rope_scaling", None) if rope_scaling is None: - scaling_type = "default" + self._init_to_get_rotary() elif "rope_type" in rope_scaling: scaling_type = rope_scaling["rope_type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) 
elif "type" in rope_scaling: scaling_type = rope_scaling["type"] + self._init_rotary_by_scaling_type(scaling_type, rope_scaling) else: raise ValueError(f"Unknown RoPE scaling format {rope_scaling}") + if "rope_theta_hw" in self.config: + self._init_to_get_hw_rotary() + super()._init_custom() + + def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling): if scaling_type == "default" or "mrope_section" in rope_scaling: self._init_to_get_rotary() elif scaling_type == "yarn": @@ -94,9 +101,6 @@ def _init_custom(self): self._init_to_get_rotary() else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - if "rope_theta_hw" in self.config: - self._init_to_get_hw_rotary() - return def _init_to_get_rotary(self, default_base=10000): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) @@ -106,7 +110,6 @@ def _init_to_get_rotary(self, default_base=10000): rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) base = self.config.get("rope_theta", float(default_base)) - if "max_sequence_length" in self.config: max_seq_len = self.config["max_sequence_length"] else: @@ -126,9 +129,10 @@ def _init_to_get_rotary(self, default_base=10000): except: pass - inv_freq = 1.0 / ( + full_inv_freq = 1.0 / ( base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) ) + inv_freq = full_inv_freq[::2] # for neo t = ( torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) / rope_scaling_factor @@ -139,6 +143,47 @@ def _init_to_get_rotary(self, default_base=10000): self._sin_cached = torch.sin(freqs).to(self.data_type).cuda() return + def _init_to_get_hw_rotary(self, default_base=10000): + partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2) + if self.config.get("rope_scaling", {}) is None: + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) + + base = self.config.get("rope_theta_hw", float(default_base)) + if "max_sequence_length" in self.config: + max_seq_len = self.config["max_sequence_length"] + else: + max_position_embeddings = self.config.get( + "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384 + ) + max_seq_len = max_position_embeddings * rope_scaling_factor + + # NTK + try: + ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) + assert ntk_alpha >= 1 + if ntk_alpha > 1: + logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") + max_seq_len *= ntk_alpha + base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula + except: + pass + + full_inv_freq = 1.0 / ( + base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) + ) + inv_freq = full_inv_freq[::2] + t = ( + torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) + / rope_scaling_factor + ) + freqs = torch.outer(t, inv_freq) + + self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda() + self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda() + return + def _init_to_get_dynamic_ntk_rotary(self): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) max_position_embeddings = self.config.get("max_position_embeddings", 2048) @@ -301,44 +346,3 @@ def _init_to_get_llama3_rotary(self, default_base=10000): self._cos_cached = torch.cos(freqs).to(self.data_type).cuda() self._sin_cached = 
torch.sin(freqs).to(self.data_type).cuda() return - - def _init_to_get_hw_rotary(self, default_base=10000): - partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2) - if self.config.get("rope_scaling", {}) is None: - rope_scaling_factor = 1.0 - else: - rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0) - - base = self.config.get("rope_theta_hw", float(default_base)) - if "max_sequence_length" in self.config: - max_seq_len = self.config["max_sequence_length"] - else: - max_position_embeddings = self.config.get( - "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384 - ) - max_seq_len = max_position_embeddings * rope_scaling_factor - - # NTK - try: - ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1)) - assert ntk_alpha >= 1 - if ntk_alpha > 1: - logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") - max_seq_len *= ntk_alpha - base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula - except: - pass - - full_inv_freq = 1.0 / ( - base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) - ) - inv_freq = full_inv_freq[::2] - t = ( - torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32) - / rope_scaling_factor - ) - freqs = torch.outer(t, inv_freq) - - self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda() - self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda() - return diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index ec181a0b8d..4517b5688a 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -42,6 +42,7 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC q_hw_2d = q_hw.reshape(q.shape[0], -1) k_t_2d = k_t.reshape(k.shape[0], -1) k_hw_2d = k_hw.reshape(k.shape[0], -1) + layer_weight.qk_norm_weight_(q_t_2d, k_t_2d, eps=self.eps_) layer_weight.qk_hw_norm_weight_(q_hw_2d, k_hw_2d, eps=self.eps_) From 10f760d30d5a0ce81fd9cb34de5be5ff0eafb725 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Mon, 23 Mar 2026 08:35:51 +0000 Subject: [PATCH 03/41] support x2i. 
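x2i runs a prefill-only LLM pass (img_gen_prefill=True), offloads the resulting KV cache pages
to a shared CPU pool, and hands those pages to a separate x2i server that conditions image
generation on them. A rough client-side sketch of the new /generate_image endpoint (host, port
and the concrete values are illustrative assumptions, not part of this patch; the server must be
started with --enable_multimodal and --enable_multimodal_x2i):

    import requests  # assumed client-side dependency, not part of this patch

    payload = {
        "inputs": "a corgi surfing a wave at sunset",
        "parameters": {
            # maps onto X2IParams; omitted fields fall back to its defaults
            "width": 512,
            "height": 512,
            "steps": 30,
            "guidance_scale": 7.0,
            "seed": 42,
            "num_images": 1,
        },
        # add images here to exercise the it2i branch instead of plain t2i
        "multimodal_params": {},
    }
    resp = requests.post("http://127.0.0.1:8000/generate_image", json=payload)
    print(resp.json()["images"])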
--- .../triton_kernel/kv_cache_offload.py | 305 ++++++++++++++++++ lightllm/server/api_cli.py | 5 + lightllm/server/api_http.py | 17 + lightllm/server/api_lightllm.py | 16 + lightllm/server/api_start.py | 24 +- lightllm/server/core/objs/__init__.py | 1 + lightllm/server/core/objs/req.py | 7 +- lightllm/server/core/objs/sampling_params.py | 6 + lightllm/server/core/objs/start_args_type.py | 3 + .../core/objs/token_chunck_hash_list.py | 11 + lightllm/server/core/objs/x2i_params.py | 80 +++++ lightllm/server/httpserver/manager.py | 124 +++++++ .../server/router/model_infer/infer_batch.py | 3 + .../model_infer/mode_backend/base_backend.py | 11 + .../model_infer/mode_backend/past_kv_cache.py | 153 +++++++++ lightllm/server/x2i_server/__init__.py | 0 lightllm/server/x2i_server/manager.py | 141 ++++++++ .../server/x2i_server/past_kv_cache_client.py | 138 ++++++++ lightllm/utils/kv_cache_utils.py | 2 +- .../kv_trans_kernel/test_kv_trans_from_gpu.py | 212 ++++++++++++ 20 files changed, 1254 insertions(+), 5 deletions(-) create mode 100644 lightllm/server/core/objs/x2i_params.py create mode 100644 lightllm/server/router/model_infer/mode_backend/past_kv_cache.py create mode 100644 lightllm/server/x2i_server/__init__.py create mode 100644 lightllm/server/x2i_server/manager.py create mode 100644 lightllm/server/x2i_server/past_kv_cache_client.py create mode 100644 unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py diff --git a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py index 0fdc43ab9f..686c0c1790 100644 --- a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py +++ b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py @@ -704,3 +704,308 @@ def load_cpu_kv_to_gpu( num_stages=1, ) return + + + +@triton.jit +def _offload_gpu_kv_to_cpu_for_x2i( + token_indexes_ptr, + gpu_kv_cache_ptr, + gpu_stride0, + gpu_stride1, + gpu_stride2, + gpu_kv_cache_scale_ptr, + gpu_scale_stride0, + gpu_scale_stride1, + gpu_scale_stride2, + cpu_kv_cache_ptr, + cpu_stride0, + cpu_stride1, + cpu_stride2, + cpu_stride3, + cpu_kv_cache_scale_ptr, + cpu_scale_stride0, + cpu_scale_stride1, + cpu_scale_stride2, + cpu_scale_stride3, + page_indexes_ptr, + layer_num, + head_dim, + scale_head_dim, + block_num, + token_num, + cpu_k_start_head_index: tl.constexpr, + cpu_k_head_num: tl.constexpr, + gpu_k_start_head_index: tl.constexpr, + gpu_k_head_num: tl.constexpr, + cpu_v_start_head_index: tl.constexpr, + cpu_v_head_num: tl.constexpr, + gpu_v_start_head_index: tl.constexpr, + gpu_v_head_num: tl.constexpr, + BLOCK_HEAD_DIM: tl.constexpr, + TOKEN_BLOCK: tl.constexpr, + HAS_SCALE: tl.constexpr, +): + block_start_index = tl.program_id(0) + block_split_size = tl.num_programs(axis=0) + + for block_index in tl.range(block_start_index, block_num, block_split_size): + cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64) + token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK) + token_range_mask = token_range < token_num + token_indexes = tl.load(token_indexes_ptr + token_range, mask=token_range_mask).to(tl.int64) + head_dim_range = tl.arange(0, BLOCK_HEAD_DIM) + head_dim_mask = head_dim_range < head_dim + scale_head_dim_mask = head_dim_range < scale_head_dim + + token_head_mask = token_range_mask[:, None] & head_dim_mask[None, :] + token_scale_mask = token_range_mask[:, None] & scale_head_dim_mask[None, :] + for layer_index in range(layer_num): + for k_head_index in range(gpu_k_head_num): + gpu_k_head_index = 
k_head_index + gpu_k_start_head_index + cpu_k_head_index = k_head_index + cpu_k_start_head_index + + gpu_ptr = ( + gpu_kv_cache_ptr + + layer_index.to(tl.int64) * gpu_stride0 + + token_indexes[:, None] * gpu_stride1 + + gpu_k_head_index.to(tl.int64) * gpu_stride2 + + head_dim_range[None, :] + ) + gpu_data = tl.load(gpu_ptr, mask=token_head_mask, other=0.0) + cpu_ptr = ( + cpu_kv_cache_ptr + + cpu_page_index * cpu_stride0 + + layer_index.to(tl.int64) * cpu_stride1 + + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2 + + cpu_k_head_index * cpu_stride3 + + head_dim_range[None, :] + ) + tl.store(cpu_ptr, gpu_data, mask=token_head_mask, cache_modifier=".wt") + + if HAS_SCALE: + gpu_scale_ptr = ( + gpu_kv_cache_scale_ptr + + layer_index.to(tl.int64) * gpu_scale_stride0 + + token_indexes[:, None] * gpu_scale_stride1 + + gpu_k_head_index.to(tl.int64) * gpu_scale_stride2 + + head_dim_range[None, :] + ) + gpu_scale_data = tl.load(gpu_scale_ptr, mask=token_scale_mask, other=0.0) + cpu_scale_ptr = ( + cpu_kv_cache_scale_ptr + + cpu_page_index * cpu_scale_stride0 + + layer_index.to(tl.int64) * cpu_scale_stride1 + + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_scale_stride2 + + cpu_k_head_index * cpu_scale_stride3 + + head_dim_range[None, :] + ) + tl.store(cpu_scale_ptr, gpu_scale_data, mask=token_scale_mask, cache_modifier=".wt",) + + + for v_head_index in range(gpu_v_head_num): + gpu_v_head_index = v_head_index + gpu_v_start_head_index + cpu_v_head_index = v_head_index + cpu_v_start_head_index + + gpu_ptr = ( + gpu_kv_cache_ptr + + layer_index.to(tl.int64) * gpu_stride0 + + token_indexes[:, None] * gpu_stride1 + + gpu_v_head_index.to(tl.int64) * gpu_stride2 + + head_dim_range[None, :] + ) + gpu_data = tl.load(gpu_ptr, mask=token_head_mask, other=0.0) + cpu_ptr = ( + cpu_kv_cache_ptr + + cpu_page_index * cpu_stride0 + + layer_index.to(tl.int64) * cpu_stride1 + + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2 + + cpu_v_head_index * cpu_stride3 + + head_dim_range[None, :] + ) + tl.store(cpu_ptr, gpu_data, mask=token_head_mask, cache_modifier=".wt") + + if HAS_SCALE: + gpu_scale_ptr = ( + gpu_kv_cache_scale_ptr + + layer_index.to(tl.int64) * gpu_scale_stride0 + + token_indexes[:, None] * gpu_scale_stride1 + + gpu_v_head_index.to(tl.int64) * gpu_scale_stride2 + + head_dim_range[None, :] + ) + gpu_scale_data = tl.load(gpu_scale_ptr, mask=token_scale_mask, other=0.0) + cpu_scale_ptr = ( + cpu_kv_cache_scale_ptr + + cpu_page_index * cpu_scale_stride0 + + layer_index.to(tl.int64) * cpu_scale_stride1 + + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_scale_stride2 + + cpu_v_head_index * cpu_scale_stride3 + + head_dim_range[None, :] + ) + tl.store(cpu_scale_ptr, gpu_scale_data, mask=token_scale_mask, cache_modifier=".wt",) + + + +@torch.no_grad() +def offload_gpu_kv_to_cpu_for_x2i( + token_indexes: torch.Tensor, + gpu_kv_cache: torch.Tensor, + gpu_kv_cache_scale: Optional[torch.Tensor], + cpu_kv_cache: torch.Tensor, + cpu_kv_cache_scale: Optional[torch.Tensor], + page_indexes: torch.Tensor, + tp_index: int, + tp_world_size: int, + grid_num: int, + _cache_data={}, +): + """ + Args: + token_indexes: (token_num, ) + gpu_kv_cache: (layer_num, token_num, head_num, head_dim) + cpu_kv_cache: (all_page_num, layer_num, token_block_size, head_num, head_dim) + page_indexes: (page_num,) + """ + + token_block_size = cpu_kv_cache.shape[2] + token_num = token_indexes.shape[0] + assert token_num <= page_indexes.shape[0] * token_block_size + + gpu_heads = gpu_kv_cache.shape[2] + gpu_head_dim = gpu_kv_cache.shape[3] + cpu_heads = 
cpu_kv_cache.shape[3]
+    cpu_head_dim = cpu_kv_cache.shape[4]
+
+    assert gpu_head_dim == cpu_head_dim
+    assert gpu_kv_cache.shape[0] == cpu_kv_cache.shape[1]
+
+    scale_size = (tp_world_size * gpu_heads) // cpu_heads
+
+    if (gpu_heads, cpu_heads, tp_index, tp_world_size) in _cache_data:
+        need_offload, head_info_tuple = _cache_data[(gpu_heads, cpu_heads, tp_index, tp_world_size)]
+    else:
+        if cpu_heads > 1:
+            assert (tp_world_size * gpu_heads) % cpu_heads == 0
+            assert cpu_heads % 2 == 0
+            cpu_heads_index = (
+                torch.arange(0, cpu_heads, device="cpu", dtype=torch.int32)
+                .view(cpu_heads, 1)
+                .tile((1, scale_size))
+                .view(2, tp_world_size, -1)
+            )
+            k_cpu_heads_index = cpu_heads_index[0][tp_index]
+            v_cpu_heads_index = cpu_heads_index[1][tp_index]
+
+            cpu_heads_index = torch.cat([k_cpu_heads_index, v_cpu_heads_index], dim=0).view(2, -1).numpy()
+            gpu_heads_index = torch.arange(0, gpu_heads, device="cpu", dtype=torch.int32).view(2, -1)
+
+            need_offload = tp_index % scale_size == 0
+
+            cpu_k_start_head_index = int(cpu_heads_index[0, 0])
+            cpu_k_head_num = len(cpu_heads_index[0])
+            gpu_k_start_head_index = int(gpu_heads_index[0, 0])
+            gpu_k_head_num = len(gpu_heads_index[0])
+            assert cpu_k_head_num == gpu_k_head_num
+            cpu_v_start_head_index = int(cpu_heads_index[1, 0])
+            cpu_v_head_num = len(cpu_heads_index[1])
+            gpu_v_start_head_index = int(gpu_heads_index[1, 0])
+            gpu_v_head_num = len(gpu_heads_index[1])
+            assert cpu_v_head_num == gpu_v_head_num
+
+        else:
+            assert gpu_heads == 1
+            assert cpu_heads == 1
+
+            need_offload = tp_index == 0
+            cpu_k_start_head_index = 0
+            cpu_k_head_num = 1
+            gpu_k_start_head_index = 0
+            gpu_k_head_num = 1
+            cpu_v_start_head_index = 0
+            cpu_v_head_num = 0
+            gpu_v_start_head_index = 0
+            gpu_v_head_num = 0
+
+        head_info_tuple = (
+            cpu_k_start_head_index,
+            cpu_k_head_num,
+            gpu_k_start_head_index,
+            gpu_k_head_num,
+            cpu_v_start_head_index,
+            cpu_v_head_num,
+            gpu_v_start_head_index,
+            gpu_v_head_num,
+        )
+        _cache_data[(gpu_heads, cpu_heads, tp_index, tp_world_size)] = (need_offload, head_info_tuple)
+
+    if not need_offload:
+        return
+
+    (
+        cpu_k_start_head_index,
+        cpu_k_head_num,
+        gpu_k_start_head_index,
+        gpu_k_head_num,
+        cpu_v_start_head_index,
+        cpu_v_head_num,
+        gpu_v_start_head_index,
+        gpu_v_head_num,
+    ) = head_info_tuple
+
+    assert token_block_size == triton.next_power_of_2(token_block_size)
+
+    page_num = page_indexes.shape[0]
+    grid = (grid_num,)
+    num_warps = 4
+    num_stages = 1
+    HAS_SCALE = gpu_kv_cache_scale is not None and cpu_kv_cache_scale is not None
+    if HAS_SCALE:
+        scale_head_dim = gpu_kv_cache_scale.shape[-1]
+        gpu_scale_stride = gpu_kv_cache_scale.stride()
+        cpu_scale_stride = cpu_kv_cache_scale.stride()
+    else:
+        scale_head_dim = 0
+        gpu_scale_stride = [0 for _ in range(5)]
+        cpu_scale_stride = [0 for _ in range(5)]
+
+    _offload_gpu_kv_to_cpu_for_x2i[grid](
+        token_indexes_ptr=token_indexes,
+        gpu_kv_cache_ptr=gpu_kv_cache,
+        gpu_stride0=gpu_kv_cache.stride(0),
+        gpu_stride1=gpu_kv_cache.stride(1),
+        gpu_stride2=gpu_kv_cache.stride(2),
+        gpu_kv_cache_scale_ptr=gpu_kv_cache_scale,
+        gpu_scale_stride0=gpu_scale_stride[0],
+        gpu_scale_stride1=gpu_scale_stride[1],
+        cpu_kv_cache_ptr=cpu_kv_cache,
+        cpu_stride0=cpu_kv_cache.stride(0),
+        cpu_stride1=cpu_kv_cache.stride(1),
+        cpu_stride2=cpu_kv_cache.stride(2),
+        cpu_stride3=cpu_kv_cache.stride(3),
+        cpu_kv_cache_scale_ptr=cpu_kv_cache_scale,
+        cpu_scale_stride0=cpu_scale_stride[0],
+        cpu_scale_stride1=cpu_scale_stride[1],
cpu_scale_stride2=cpu_scale_stride[2], + cpu_scale_stride3=cpu_scale_stride[3], + page_indexes_ptr=page_indexes, + layer_num=gpu_kv_cache.shape[0], + head_dim=gpu_head_dim, + scale_head_dim=scale_head_dim, + block_num=page_num, + token_num=token_num, + cpu_k_start_head_index=cpu_k_start_head_index, + cpu_k_head_num=cpu_k_head_num, + gpu_k_start_head_index=gpu_k_start_head_index, + gpu_k_head_num=gpu_k_head_num, + cpu_v_start_head_index=cpu_v_start_head_index, + cpu_v_head_num=cpu_v_head_num, + gpu_v_start_head_index=gpu_v_start_head_index, + gpu_v_head_num=gpu_v_head_num, + BLOCK_HEAD_DIM=triton.next_power_of_2(gpu_head_dim), + TOKEN_BLOCK=token_block_size, + HAS_SCALE=HAS_SCALE, + num_warps=num_warps, + num_stages=num_stages, + ) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index d32da8097c..06d8c57320 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -302,6 +302,11 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="if the model is a multimodal model, set to not load audio part model.", ) + parser.add_argument( + "--enable_multimodal_x2i", + action="store_true", + help="Whether or not to allow to generate images (requird --enable_multimodal)." + ) parser.add_argument( "--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service." ) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 230da5b369..4c366fb770 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -70,6 +70,7 @@ class G_Objs: args: StartArgs = None g_generate_func: Callable = None g_generate_stream_func: Callable = None + g_generate_image_func: Callable = None httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None shared_token_load: TokenLoad = None @@ -85,6 +86,10 @@ def set_args(self, args: StartArgs): self.g_generate_func = lightllm_generate self.g_generate_stream_func = lightllm_generate_stream + if args.enable_multimodal_x2i: + from .api_lightllm import lightllm_generate_image + self.g_generate_image_func = lightllm_generate_image + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::api_server") if args.run_mode == "pd_master": @@ -257,6 +262,18 @@ async def completions(request: CompletionRequest, raw_request: Request) -> Respo resp = await completions_impl(request, raw_request) return resp +@app.post("/generate_image") +async def generate_image(request: Request) -> Response: + if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + return create_error_response( + HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" + ) + + try: + return await g_objs.g_generate_image_func(request, g_objs.httpserver_manager) + except Exception as e: + return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) + @app.get("/tokens") @app.post("/tokens") diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f54..b20670e47c 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -3,6 +3,7 @@ from fastapi import BackgroundTasks, Request from fastapi.responses import Response, StreamingResponse from lightllm.server.core.objs.sampling_params import SamplingParams +from lightllm.server.core.objs.x2i_params import X2IParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager import ujson as json @@ -150,3 +151,18 @@ async def stream_results() -> 
AsyncGenerator[bytes, None]: background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) + + +async def lightllm_generate_image(request: Request, httpserver_manager: HttpServerManager) -> Response: + # 这个接口目前主要是给x2v gen用的,输入是文本,输出是图片特征 + request_dict = await request.json() + prompt = request_dict.pop("inputs") + generation_params_dict = request_dict["parameters"] + generation_params = X2IParams() + generation_params.init(**generation_params_dict) + multimodal_params_dict = request_dict.get("multimodal_params", {}) + multimodal_params = MultimodalParams(**multimodal_params_dict) + + results = await httpserver_manager.generate_image(prompt, generation_params, multimodal_params, request=request) + + return Response(content=json.dumps({"images": results}, ensure_ascii=False).encode("utf-8")) \ No newline at end of file diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 364f9ca281..dc347e4929 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -118,6 +118,8 @@ def normal_or_p_d_start(args): if not args.disable_shm_warning: check_recommended_shm_size(args) + if args.enable_multimodal_x2i: + args.multi_modal_x2i_cache_shm_id = uuid.uuid1().int % 123456789 assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] # 确保单机上多实列不冲突 @@ -248,7 +250,7 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports + num=12+ node_world_size + args.visual_dp * (args.visual_tp + 1), used_nccl_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -262,8 +264,10 @@ def normal_or_p_d_start(args): metric_port, multi_level_kv_cache_port, pd_decode_rpyc_port, - ) = can_use_ports[0:10] - can_use_ports = can_use_ports[10:] + x2i_port, + http_server_port_for_x2i, + ) = can_use_ports[0:12] + can_use_ports = can_use_ports[12:] visual_model_tp_ports = [] visual_nccl_ports = [] @@ -288,6 +292,8 @@ def normal_or_p_d_start(args): args.metric_port = metric_port args.multi_level_kv_cache_port = multi_level_kv_cache_port args.visual_nccl_ports = visual_nccl_ports + args.x2i_port = x2i_port + args.http_server_port_for_x2i = http_server_port_for_x2i # 申请在 p d 分离模式下,会用的端口 args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] # p d 分离模式下用于标识节点的id @@ -346,6 +352,18 @@ def normal_or_p_d_start(args): ], ) + if args.enable_multimodal_x2i: + from .x2i_server.manager import start_x2i_process + + process_manager.start_submodule_processes( + start_funcs=[ + start_x2i_process, + ], + start_args=[ + (args,), + ], + ) + if args.enable_cpu_cache: from .multi_level_kv_cache.manager import start_multi_level_kv_cache_manager diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py index ec2438de9d..06ee53be33 100644 --- a/lightllm/server/core/objs/__init__.py +++ b/lightllm/server/core/objs/__init__.py @@ -1,4 +1,5 @@ from .sampling_params import SamplingParams +from .x2i_params import X2IParams from .req import Req, FinishStatus from .shm_req_manager import ShmReqManager from .start_args_type import StartArgs diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 8905248bf8..3ca4b55c6a 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -6,7 +6,7 @@ from .sampling_params import SamplingParams from .out_token_circlequeue import 
CircularQueue from .shm_array import ShmArray -from .token_chunck_hash_list import TokenHashList, CpuCachePageList +from .token_chunck_hash_list import TokenHashList, CpuCachePageList, PastKVCachePageList from lightllm.server.req_id_generator import convert_sub_id_to_group_id from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.envs_utils import get_env_start_args @@ -122,6 +122,8 @@ class Req(ctypes.Structure): ("cpu_cache_match_page_indexes", CpuCachePageList), # 分块hash的块大小 ("cpu_cache_token_page_size", ctypes.c_int), + # 用于图片生成场景,记录请求对应的kv cache页面信息,供生成过程使用。 + ("past_kv_cache_page_indexes", PastKVCachePageList), ] def get_str(self): @@ -185,6 +187,9 @@ def init( self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size if get_env_start_args().enable_cpu_cache: self._fill_input_token_hash() + + if sample_param.img_gen_prefill: + self.past_kv_cache_page_indexes = PastKVCachePageList(self.input_len) return def post_init(self): diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 49b21c38fc..3cf5cb0887 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -322,6 +322,7 @@ class SamplingParams(ctypes.Structure): ("print_eos_token", ctypes.c_bool), # eos_id will be always ignored except the value is set to True ("disable_prompt_cache", ctypes.c_bool), # whether to disable prompt cache ("seed", ctypes.c_int64), # random seed + ("img_gen_prefill", ctypes.c_bool), # whether to prefill for image generation, need return past key values back ] _do_sample: bool = False @@ -354,6 +355,8 @@ def init(self, tokenizer, **kwargs): self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS) self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False) + self.img_gen_prefill = kwargs.get("img_gen_prefill", False) + self.add_special_tokens = kwargs.get("add_special_tokens", True) self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True) self.print_eos_token = kwargs.get("print_eos_token", False) @@ -404,6 +407,9 @@ def init(self, tokenizer, **kwargs): self.temperature = 1.0 self.top_k = 1 + if self.img_gen_prefill: + self.max_new_tokens = 1 + self.verify() @classmethod diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 37c022f3a3..b741847b4c 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -160,3 +160,6 @@ class StartArgs: metric_port: int = field(default=None) multinode_httpmanager_port: int = field(default=12345) multi_level_kv_cache_port: int = field(default=None) + x2i_port: int = field(default=None) + http_server_port_for_x2i: int = field(default=None) + enable_multimodal_x2i: bool = field(default=False) diff --git a/lightllm/server/core/objs/token_chunck_hash_list.py b/lightllm/server/core/objs/token_chunck_hash_list.py index a79ff48a85..479c6c17ff 100644 --- a/lightllm/server/core/objs/token_chunck_hash_list.py +++ b/lightllm/server/core/objs/token_chunck_hash_list.py @@ -88,3 +88,14 @@ def clear(self): def get_all(self): return list(self.items[0 : self.size]) + + +class PastKVCachePageList(CpuCachePageList): + _pack_ = 4 + _fields_ = CpuCachePageList._fields_ +[ + ("token_len", ctypes.c_int), # 对应的token数量 + ] + + def __init__(self, token_len: int = 0): + super().__init__() + self.token_len = token_len diff --git 
a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py new file mode 100644 index 0000000000..66fcbe9afb --- /dev/null +++ b/lightllm/server/core/objs/x2i_params.py @@ -0,0 +1,80 @@ +import ctypes +from dataclasses import dataclass +from typing import Dict, List +from enum import IntEnum +from .token_chunck_hash_list import PastKVCachePageList + +class CfgNormType(IntEnum): + NONE = 0 + CFG_ZERO_STAR = 1 + GLOBAL = 2 + + +class X2IParams(ctypes.Structure): + _pack_ = 4 + _fields_ = [ + ("width", ctypes.c_int), + ("height", ctypes.c_int), + ("steps", ctypes.c_int), + ("guidance_scale", ctypes.c_float), + ("image_guidance_scale", ctypes.c_float), + ("seed", ctypes.c_int), + ("num_images", ctypes.c_int), + ("cfg_norm", ctypes.c_int), + ("past_kvcache", PastKVCachePageList), + ("past_kvcache_text", PastKVCachePageList), + ("past_kvcache_img", PastKVCachePageList), + ("total_prompt_tokens", ctypes.c_int), + ("request_id", ctypes.c_int64), + ] + + _width: int = 512 + _height: int = 512 + _steps: int = 30 + _guidance_scale: float = 7.0 + _image_guidance_scale: float = 7.0 + _seed: int = 42 + _num_images: int = 1 + _cfg_norm: CfgNormType = CfgNormType.NONE + + def init(self, **kwargs): + def _get(key, default): + v = kwargs.get(key) + return v if v is not None else default + self.width = _get("width", X2IParams._width) + self.height = _get("height", X2IParams._height) + self.steps = _get("steps", X2IParams._steps) + self.guidance_scale = _get("guidance_scale", X2IParams._guidance_scale) + self.image_guidance_scale = _get("image_guidance_scale", X2IParams._image_guidance_scale) + self.seed = _get("seed", X2IParams._seed) + self.num_images = _get("num_images", X2IParams._num_images) + self.cfg_norm = _get("cfg_norm", X2IParams._cfg_norm) + self.past_kvcache = PastKVCachePageList() + self.past_kvcache_text = PastKVCachePageList() + self.past_kvcache_img = PastKVCachePageList() + self.total_prompt_tokens = 0 + self.request_id = 0 + + def update(self, past_kv: PastKVCachePageList, meta: Dict): + past_kv.token_len = meta.get("prompt_tokens") + past_kv.fill(meta.get("kv_cache_pages")) + self.total_prompt_tokens += past_kv.token_len + + def update_t2i(self, meta, meta_uncond): + self.update(self.past_kvcache, meta) + self.update(self.past_kvcache_text, meta_uncond) + + def update_it2i(self, meta, meta_text_uncond, meta_img_uncond): + self.update(self.past_kvcache, meta) + self.update(self.past_kvcache_text, meta_text_uncond) + self.update(self.past_kvcache_img, meta_img_uncond) + + +@dataclass +class X2IResponse: + request_id: int + images: List[bytes] + +@dataclass +class X2ICacheRelease: + request_id: int \ No newline at end of file diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index e28e4c93ad..1c943e4937 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -10,6 +10,7 @@ import hashlib import datetime import pickle +import re from frozendict import frozendict asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -24,6 +25,7 @@ from .async_queue import AsyncQueue from lightllm.server.core.objs import Req, FinishStatus, StartArgs from lightllm.server.core.objs import SamplingParams +from lightllm.server.core.objs.x2i_params import X2IParams, X2ICacheRelease, X2IResponse from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE from lightllm.server.core.objs.io_objs import GroupReqObjs from lightllm.server.core.objs.shm_req_manager import 
ShmReqManager @@ -95,6 +97,15 @@ def __init__( self.send_to_multi_level_kv_cache = context.socket(zmq.PUSH) self.send_to_multi_level_kv_cache.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") + if args.enable_multimodal_x2i: + from lightllm.server.x2i_server.past_kv_cache_client import PastKVCacheClient + self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=True, init_shm_data=False) + self.send_to_x2i = context.socket(zmq.PUSH) + self.send_to_x2i.connect(f"{args.zmq_mode}127.0.0.1:{args.x2i_port}") + self.recv_from_x2i = context.socket(zmq.PULL) + self.recv_from_x2i.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port_for_x2i}") + self.req_id_to_x2i_reqs: Dict[int, X2IReqStatus] = {} + self.shm_req_manager = ShmReqManager() # recv from detokenization @@ -354,6 +365,15 @@ async def generate( self.tokenizer, chunked_prefill_size=self.args.chunked_prefill_size, ) + if sampling_params.img_gen_prefill: + # allocate pages, may block if cache is full, but it won't cause deadlock + # because the prefill process is designed to be sequential and the pages + # will be released after prefill. + kv_pages = self.past_kv_cache_client.allocate_pages( + req_obj.request_id, req_obj.input_len) + + req_obj.past_kv_cache_page_indexes.fill(kv_pages) + req_objs.append(req_obj) logger.debug( @@ -404,10 +424,82 @@ async def generate( # 进行回收。 if group_request_id not in self.req_id_to_out_inf: await self._release_multimodal_resources(multimodal_params) + + if sampling_params.img_gen_prefill: + # 预分配了 kv cache 的请求,在异常情况下需要主动释放 kv cache 资源 + self.past_kv_cache_client.free_pages_by_req_id(group_request_id) + await self.abort(group_request_id) + raise e return + + async def generate_image(self, prompt: str, generation_params: X2IParams, multimodal_params: MultimodalParams, request: Request): + generate_req_ids = [] + async def generation_wrapper(prompt, sample, multimodal, request): + async for sub_req_id, _, metadata, finish_status in self.generate( + prompt, sample, multimodal, request + ): + kv_cache_pages = self.past_kv_cache_client.get_pages_by_req_id(sub_req_id) + if kv_cache_pages is None: + raise Exception(f"kv_cache_pages is None for sub_req_id {sub_req_id}") + metadata["kv_cache_pages"] = kv_cache_pages + metadata["request_id"] = sub_req_id + metadata["finish_status"] = finish_status + generate_req_ids.append(sub_req_id) + return metadata + + try: + # 1. 
construct 3 or 2 images based on the multimodel_parmas + sample_params = SamplingParams() + sample_params.init(self.tokenizer, **{"img_gen_prefill": True}) + img_len = len(multimodal_params.images) + + if img_len > 0: + # call it2i + prompt_condition = f"Please generate an image based on the following instruction: {prompt}" + prompt_text_uncondition = "Please generate an image based on the following instruction: "+ '\n' * img_len + prompt_img_uncondition = "Please generate an image based on the following instruction: " + re.sub(r"\n?", "", prompt) + (con_gen, text_uncon_gen, img_uncon_gen) = await asyncio.gather(*[ + generation_wrapper(prompt_condition, sample_params, multimodal_params, request), + generation_wrapper(prompt_text_uncondition, sample_params, multimodal_params, request), + generation_wrapper(prompt_img_uncondition, sample_params, MultimodalParams(), request)]) + generation_params.update_it2i(con_gen, text_uncon_gen, img_uncon_gen) + else: + # call t2i + prompt_condition = f"Please generate an image based on the following caption: {prompt}" + prompt_uncondition = f"Please generate an image based on the following caption: " + (con_gen, uncon_gen) = await asyncio.gather(*[ + generation_wrapper(prompt_condition, sample_params, multimodal_params, request), + generation_wrapper(prompt_uncondition, sample_params, multimodal_params, request)]) + generation_params.update_t2i(con_gen, uncon_gen) + # use the first reqeust id as the gen image request id + x2i_req_id = generate_req_ids[0] + generation_params.request_id = x2i_req_id + + req_status = X2IReqStatus(generation_params, generate_req_ids) + self.req_id_to_x2i_reqs[generation_params.request_id] = req_status + + # send generation_params to generation server for image generation + await self.send_to_x2i.send_pyobj(generation_params, protocol=pickle.HIGHEST_PROTOCOL) + + await req_status.event.wait() + + assert req_status.response is not None + + self.req_id_to_x2i_reqs.pop(x2i_req_id, None) + + return req_status.response.images + + except Exception as e: + logger.error(str(e)) + pass + + finally: + for req_id in generate_req_ids: + self.past_kv_cache_client.free_pages_by_req_id(req_id) + def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]: image_tokens = 0 audio_tokens = 0 @@ -739,6 +831,28 @@ async def recycle_resource_loop(self): ) return + async def loop_for_x2i(self): + + while True: + try: + recv_obj = await asyncio.wait_for(self.recv_from_x2i.recv_pyobj(), timeout=0.05) + + if isinstance(recv_obj, X2ICacheRelease): + status = self.req_id_to_x2i_reqs[recv_obj.request_id] + for req_id in status.req_ids: + self.past_kv_cache_client.free_pages_by_req_id(req_id) + + elif isinstance(recv_obj, X2IResponse): + status = self.req_id_to_x2i_reqs[recv_obj.request_id] + status.response = recv_obj + status.event.set() + + except asyncio.TimeoutError: + pass + except Exception as e: + logger.error(e) + + async def handle_loop(self): self.recycle_event = asyncio.Event() asyncio.create_task(self.recycle_resource_loop()) @@ -753,6 +867,9 @@ async def handle_loop(self): asyncio.create_task(pd_handle_loop(self)) + if self.args.enable_multimodal_x2i: + asyncio.create_task(self.loop_for_x2i()) + while True: try: await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) @@ -839,3 +956,10 @@ def can_release(self): if not req.can_release(): return False return True + +class X2IReqStatus: + def __init__(self, req_param: X2IParams, req_ids: List[int]): + self.req = X2IParams + self.req_ids = req_ids + 
self.event = asyncio.Event() + self.response: X2IResponse = None \ No newline at end of file diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 0a83b101be..0088c60d65 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -357,6 +357,9 @@ def __init__( # 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态 self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED + # img gen req need copy kv to cpu + self.past_kv_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED + # mtp_step 用来记录一个请求 draft模型每步需要生成的token数量 # 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量 self.mtp_step: int = get_env_start_args().mtp_step diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 8b085c45ed..12e418824b 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -47,6 +47,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet from .multi_level_kv_cache import MultiLevelKvCacheModule +from .past_kv_cache import PastKVCacheModule class ModeBackend: @@ -147,6 +148,9 @@ def init_model(self, kvargs): if self.args.enable_multimodal: g_infer_context.init_cpu_embed_cache_client() + if self.args.enable_multimodal_x2i: + self.past_kv_cache_module = PastKVCacheModule(self) + model_cfg, _ = PretrainedConfig.get_config_dict(self.weight_dir) model_kvargs = { @@ -551,6 +555,9 @@ def _get_classed_reqs( if self.args.enable_cpu_cache and len(g_infer_context.infer_req_ids) > 0: self.multi_level_cache_module.update_cpu_cache_task_states() + if self.args.enable_multimodal_x2i and len(g_infer_context.infer_req_ids) > 0: + self.past_kv_cache_module.update_past_kv_cache_task_states() + if req_ids is None: req_ids = g_infer_context.infer_req_ids @@ -648,6 +655,10 @@ def _get_classed_reqs( else: true_finished_reqs = finished_reqs + if self.args.enable_multimodal_x2i: + true_finished_reqs = self.past_kv_cache_module.offload_finished_reqs_to_past_kv_cache( + finished_reqs=true_finished_reqs) + g_infer_context.filter_reqs(finished_reqs=true_finished_reqs) g_infer_context.pause_reqs(wait_pause_reqs, is_master_in_dp=self.is_master_in_dp) diff --git a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py new file mode 100644 index 0000000000..de80da5d01 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py @@ -0,0 +1,153 @@ +import torch +import torch.distributed as dist +from dataclasses import dataclass +from typing import Optional, List, Deque +from collections import deque +from functools import lru_cache +from lightllm.server.x2i_server.past_kv_cache_client import PastKVCacheClient +from lightllm.server.router.model_infer.infer_batch import InferReq +from lightllm.server.router.model_infer.infer_batch import g_infer_context +from lightllm.utils.dist_utils import create_new_group_for_current_dp +from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu_for_x2i +from lightllm.server.core.objs.token_chunck_hash_list import LIGHTLLM_TOKEN_HASH_LIST_SIZE + +from lightllm.utils.log_utils import init_logger + +logger = 
init_logger(__name__) + + +@dataclass +class TransTask: + req_obj: InferReq + sync_event: torch.cuda.Event + + +class PastKVCacheModule(object): + def __init__(self, backend): + from .base_backend import ModeBackend + self.backend: ModeBackend = backend + self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=False) + self.page_index_buffer = torch.empty((LIGHTLLM_TOKEN_HASH_LIST_SIZE * 2,), dtype=torch.int32, device="cuda") + self.past_kv_cache_task: Deque[TransTask] = deque() + self.sync_task_status_group = create_new_group_for_current_dp("gloo") + + @lru_cache() + def need_sync_compute_stream(self) -> bool: + """ + fa3 在 offload 和 load kv cache 的时候,需要等待计算流完成,否则可能会概率崩溃。 + """ + + model = self.backend.model + att_backends = [ + model.prefill_att_backend, + model.decode_att_backend, + model.prefill_att_backend1, + model.decode_att_backend1, + ] + for att_backend in att_backends: + if att_backend is not None and "fa3" in att_backend.__class__.__name__.lower(): + logger.info("PastKVCacheModule: need sync compute stream for fa3 backend.") + return True + logger.info("PastKVCacheModule: no need sync compute stream.") + return False + + + def offload_finished_reqs_to_past_kv_cache(self, finished_reqs: List[InferReq]) -> List[InferReq]: + """ + Offload the finished reqs to past kv cache, and return the truly finished reqs that can be freed in infer batch. + """ + true_finished_reqs = [] + for req in finished_reqs: + # filter out non-img-gen reqs + if not req.shm_req.sample_params.img_gen_prefill: + true_finished_reqs.append(req) + continue + + if req.past_kv_cache_task_status.is_finished(): + true_finished_reqs.append(req) + continue + + if req.past_kv_cache_task_status.is_running(): + continue + + assert req.past_kv_cache_task_status.is_not_started() + + if self.need_sync_compute_stream(): + g_infer_context.get_overlap_stream().synchronize() + + trans_task = self._start_kv_cache_offload(req=req) + assert trans_task is not None + self.past_kv_cache_task.append(trans_task) + + + return true_finished_reqs + + def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]: + + with torch.cuda.stream(g_infer_context.get_cpu_kv_cache_stream()): + page_indexes = torch.tensor(req.shm_req.past_kv_cache_page_indexes.get_all(), dtype=torch.int32, device='cpu', pin_memory=True) + num_tokens = req.shm_req.input_len + + assert req.cur_kv_len >= num_tokens + assert num_tokens <= len(page_indexes) * self.past_kv_cache_client.token_page_size + + cuda_page_indexes = self.page_index_buffer[:len(page_indexes)] + cuda_page_indexes.copy_(page_indexes) + + token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0: num_tokens] + mem_manager = self.backend.model.mem_manager + + + if hasattr(mem_manager, "scale_buffer") and mem_manager.scale_buffer is not None: + cpu_cache_meta = self.past_kv_cache_client.kv_cache_tensor_meta + cpu_kv_cache = self.past_kv_cache_client.cpu_kv_cache_tensor[:, :, :, :, 0:cpu_cache_meta.head_dim] + cpu_kv_cache_scale = self.past_kv_cache_client.cpu_kv_cache_tensor[ + :, :, :, :, cpu_cache_meta.head_dim + ].view(mem_manager.scale_buffer.dtype) + gpu_kv_cache_scale = mem_manager.scale_buffer + else: + cpu_kv_cache = self.past_kv_cache_client.cpu_kv_cache_tensor + cpu_kv_cache_scale = None + gpu_kv_cache_scale = None + + grid_num = 16 + offload_gpu_kv_to_cpu_for_x2i( + token_indexes=token_indexes, + gpu_kv_cache=mem_manager.kv_buffer, + gpu_kv_cache_scale=gpu_kv_cache_scale, + cpu_kv_cache=cpu_kv_cache, + 
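+                # scale tensors are forwarded only when the mem manager quantizes KV and exposes a
+                # scale_buffer; for unquantized caches both stay None and the kernel skips the scale copies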
cpu_kv_cache_scale=cpu_kv_cache_scale, + page_indexes=cuda_page_indexes, + tp_index=self.backend.rank_in_dp, + tp_world_size=self.backend.dp_world_size, + grid_num=grid_num, + ) + sync_event = torch.cuda.Event() + sync_event.record() + req.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING + return TransTask( + req_obj=req, + sync_event=sync_event, + ) + + def update_past_kv_cache_task_states(self): + trans_ok_tasks = [] + while len(self.past_kv_cache_task) > 0: + task: TransTask = self.past_kv_cache_task.popleft() + if task.sync_event.query(): + trans_ok_tasks.append(task) + else: + self.past_kv_cache_task.appendleft(task) + break + + if len(trans_ok_tasks) == 0: + return + + ok_tasks_num = torch.tensor(len(trans_ok_tasks)) + dist.all_reduce(ok_tasks_num, op=dist.ReduceOp.MIN, group=self.sync_task_status_group) + + if ok_tasks_num.item() > 0: + finished, unfinished = trans_ok_tasks[:ok_tasks_num.item()], trans_ok_tasks[ok_tasks_num.item():] + self.past_kv_cache_task.extendleft(reversed(unfinished)) + for task in finished: + task.req_obj.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED diff --git a/lightllm/server/x2i_server/__init__.py b/lightllm/server/x2i_server/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py new file mode 100644 index 0000000000..3dc82be897 --- /dev/null +++ b/lightllm/server/x2i_server/manager.py @@ -0,0 +1,141 @@ +import zmq +import zmq.asyncio +import asyncio +import uvloop +import inspect +import setproctitle +import pickle +import torch +from typing import List +from lightllm.server.core.objs import StartArgs + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +from lightllm.utils.log_utils import init_logger +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.utils.process_check import start_parent_check_thread +from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease +from .past_kv_cache_client import PastKVCacheClient + +logger = init_logger(__name__) + +''' +manage a generation service, +1. start x2v pipelines +2. receive generation request from http_server. +3. call llm gen to obtain past key values +4. call x2v to generate images and pass the key values to it +5. return the generated images. 
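+
+A rough sketch of one loop_for_fwd iteration following the steps above
+(illustrative only; in this commit the image-generation step is still a stub
+and an empty image list is returned):
+
+    x2i_param = self.waiting_reqs.pop(0)                      # step 2: queued X2IParams
+    kv = self.past_kv_cache_client.get_kv_cache_for_x2i(      # step 3: prefix KV from shared memory
+        x2i_param.past_kvcache.get_all(), x2i_param.past_kvcache.token_len)
+    images = []                                               # step 4: x2v pipeline not wired up yet
+    self.send_to_httpserver.send_pyobj(                       # step 5: reply to the http server
+        X2IResponse(request_id=x2i_param.request_id, images=images),
+        protocol=pickle.HIGHEST_PROTOCOL)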
+''' + +class X2IManager: + def __init__( + self, + args: StartArgs, + ): + context = zmq.Context(2) + self.args = args + + self.zmq_recv_socket = context.socket(zmq.PULL) + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.x2i_port}") + + self.send_to_httpserver = context.socket(zmq.PUSH) + self.send_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port_for_x2i}") + + self.waiting_reqs: List[X2IParams] = [] + + from lightllm.utils.dist_utils import set_current_device_id + + set_current_device_id(torch.cuda.current_device()) + + self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True) + + async def wait_to_model_ready(self): + # from lightx2v import LightX2VPipeline + # self.gen_pipe = LightX2VPipeline( + # model_path = self.args.model_dir, + # model_cls = self.args.model_name, + # task="t2i" + # ) + # self.gen_pipe.create_generator( + # config_json = self.args.x2v_gen_model_config, + # ) + + pass + + async def loop_for_fwd(self): + while True: + try: + if len(self.waiting_reqs) == 0: + await asyncio.sleep(0.01) + continue + + x2i_param = self.waiting_reqs.pop(0) + + past_kv_cache = self.past_kv_cache_client.get_kv_cache_for_x2i( + x2i_param.past_kvcache.get_all(), x2i_param.past_kvcache.token_len + ) + + past_kv_cache_text = self.past_kv_cache_client.get_kv_cache_for_x2i( + x2i_param.past_kvcache_text.get_all(), x2i_param.past_kvcache_text.token_len + ) + is_t2i = x2i_param.past_kvcache_img.is_empty() + + logger.info(f"past kv cache shape: {past_kv_cache.shape}, past_kv_cache_text shape: {past_kv_cache_text.shape}") + + past_kv_cache_img = None + if not is_t2i: # t2i + past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i( + x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len + ) + + # release + self.send_to_httpserver.send_pyobj( + X2ICacheRelease(request_id=x2i_param.request_id), + protocol=pickle.HIGHEST_PROTOCOL) + + # call generate images + self.send_to_httpserver.send_pyobj(X2IResponse( + request_id=x2i_param.request_id, + images=[]), + protocol=pickle.HIGHEST_PROTOCOL) + + except Exception as e: + logger.error(e) + + + async def loop_for_netio_req(self): + while True: + try: + recv_req: X2IParams = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + self.waiting_reqs.append(recv_req) + + except zmq.ZMQError: + await asyncio.sleep(0.1) + + await asyncio.sleep(0.01) + +def start_x2i_process(args, pipe_writer): + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::x2i_server") + start_parent_check_thread() + try: + x2iserver = X2IManager(args=args,) + asyncio.run(x2iserver.wait_to_model_ready()) + except Exception as e: + logger.exception(str(e)) + x2iserver.clean_up() + raise e + + pipe_writer.send("init ok") + + def handle_exception(loop, context): + logger.exception(f"X2IServer Caught exception: {str(context)}") + + loop = asyncio.new_event_loop() + loop.set_exception_handler(handle_exception) + asyncio.set_event_loop(loop) + loop.create_task(x2iserver.loop_for_fwd()) + loop.run_until_complete(x2iserver.loop_for_netio_req()) + return diff --git a/lightllm/server/x2i_server/past_kv_cache_client.py b/lightllm/server/x2i_server/past_kv_cache_client.py new file mode 100644 index 0000000000..4a95635d1e --- /dev/null +++ b/lightllm/server/x2i_server/past_kv_cache_client.py @@ -0,0 +1,138 @@ +import ctypes +import torch +import numpy as np +from threading import Lock, Condition +from dataclasses import 
dataclass +from lightllm.utils.envs_utils import get_env_start_args +from typing import List, Optional, Tuple +from lightllm.utils.log_utils import init_logger +from lightllm.utils.kv_cache_utils import ( + calcu_cpu_cache_meta, + create_shm_kv_cache_ptr, + attach_shm_kv_cache_ptr, + register_shm_ptr_to_pin, +) + +logger = init_logger(__name__) + + +@dataclass +class PastKVCacheItem: + req_id: int + token_len: int + page_indexes: List[int] + + +class PastKVCacheClient(object): + """ + This class is responsible for passing kv cache between generation server and model server, + and manage the shared memory for kv cache. + """ + + def __init__(self, only_create_meta_data: bool, init_shm_data: bool): + self.args = get_env_start_args() + # to do here need calcu from from settings. + self.kv_cache_tensor_meta = calcu_cpu_cache_meta() + self.page_num: int = self.kv_cache_tensor_meta.page_num + self.token_page_size: int = self.kv_cache_tensor_meta.token_page_size + self.allocated_pages_dict: dict[int, PastKVCacheItem] = {} + self.free_pages: List[int] = list(range(self.page_num)) + self.lock = Lock() + self.cond = Condition(self.lock) + + if not only_create_meta_data: + if init_shm_data: + self._create_shm_cpu_kv_cache() + self.attach_shm_handle = None + else: + self.attach_shm_handle = self._attach_shm_cpu_kv_cache() + self.attach_shm_handle.wait() + return + + def allocate_pages(self, req_id: int, need_tokens: int) -> List[int]: + need_pages = (need_tokens + self.token_page_size - 1) // self.token_page_size + if need_pages > self.page_num: + logger.error( + f"Request {req_id} need {need_tokens} tokens, which requires {need_pages} pages, " + f"exceeds the total page number {self.page_num}" + ) + raise ValueError(f"error allocate pages for request {req_id} with {need_tokens} tokens") + + with self.cond: + while len(self.free_pages) < need_pages: + self.cond.wait() + + page_indexes, self.free_pages = self.free_pages[:need_pages], self.free_pages[need_pages:] + self.allocated_pages_dict[req_id] = PastKVCacheItem( + req_id=req_id, token_len=need_tokens, page_indexes=page_indexes) + + return page_indexes + + def free_pages_by_req_id(self, req_id: int): + with self.cond: + item = self.allocated_pages_dict.pop(req_id, None) + if item is not None: + self.free_pages.extend(item.page_indexes) + self.cond.notify_all() + + def get_pages_by_req_id(self, req_id: int) -> Optional[List[int]]: + with self.lock: + item = self.allocated_pages_dict.get(req_id, None) + return item.page_indexes if item is not None else None + + def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optional[torch.Tensor]: + if page_indexes is None: + return None + assert token_num <= len(page_indexes) * self.token_page_size and \ + token_num > (len(page_indexes) - 1) * self.token_page_size + (P, L, S, H, D) = self.cpu_kv_cache_tensor[page_indexes].shape + # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (2, L, H // 2, P, S, D) + # -> (2, L, H // 2, P * S, D) -> ( L, 2, H // 2, P * S, D) + kv = self.cpu_kv_cache_tensor[page_indexes] \ + .view(P, L, S, 2, H // 2, D) \ + .permute(3, 1, 4, 0, 2, 5).contiguous() \ + .view(2, L, H // 2, P * S, D) \ + .permute(1, 0, 2, 3, 4) + return kv[:, :, :, :token_num, :].contiguous() + + def _create_shm_cpu_kv_cache(self): + shm_ptr = create_shm_kv_cache_ptr( + key=self.args.multi_modal_x2i_cache_shm_id, size=self.kv_cache_tensor_meta.calcu_size() + ) + numpy_array = np.frombuffer( + memoryview((ctypes.c_uint8 * self.kv_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), 
dtype=np.uint8 + ) + # 将 NumPy 数组转换为 PyTorch 张量 + shape = ( + self.kv_cache_tensor_meta.page_num, + self.kv_cache_tensor_meta.layer_num, + self.kv_cache_tensor_meta.token_page_size, + self.kv_cache_tensor_meta.num_heads, + self.kv_cache_tensor_meta.get_merged_head_dim(), + ) + self.cpu_kv_cache_tensor = ( + torch.from_numpy(numpy_array).view(dtype=self.kv_cache_tensor_meta.data_type).view(shape) + ) + return + + def _attach_shm_cpu_kv_cache(self): + shm_ptr = attach_shm_kv_cache_ptr( + key=self.args.multi_modal_x2i_cache_shm_id, size=self.kv_cache_tensor_meta.calcu_size() + ) + handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.kv_cache_tensor_meta.calcu_size()) + numpy_array = np.frombuffer( + memoryview((ctypes.c_uint8 * self.kv_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), dtype=np.uint8 + ) + shape = ( + self.kv_cache_tensor_meta.page_num, + self.kv_cache_tensor_meta.layer_num, + self.kv_cache_tensor_meta.token_page_size, + self.kv_cache_tensor_meta.num_heads, + self.kv_cache_tensor_meta.get_merged_head_dim(), + ) + self.cpu_kv_cache_tensor = ( + torch.from_numpy(numpy_array).view(dtype=self.kv_cache_tensor_meta.data_type).view(shape) + ) + assert shm_ptr == self.cpu_kv_cache_tensor.data_ptr() + + return handle diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py index 10764e24b0..de0e3f76ca 100644 --- a/lightllm/utils/kv_cache_utils.py +++ b/lightllm/utils/kv_cache_utils.py @@ -59,7 +59,7 @@ def compute_token_list_hash(tokens: List[int], cpu_cache_token_page_size: int) - @lru_cache(maxsize=None) def calcu_cpu_cache_meta() -> "CpuKVCacheMeta": args = get_env_start_args() - assert args.enable_cpu_cache + assert args.enable_cpu_cache or args.enable_multimodal_x2i mem_manager_class = select_mem_manager_class() if mem_manager_class is Deepseek2MemoryManager: diff --git a/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py b/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py new file mode 100644 index 0000000000..ca6c5043c9 --- /dev/null +++ b/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py @@ -0,0 +1,212 @@ +import pytest +import torch + +from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu_all + +# ========================================================= +# GPU guard +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available" +) +# ========================================================= +# 工具函数:生成 GPU KV cache +def make_kv(L, T, H, D, device="cuda"): + x = torch.arange(L * T * H * D, dtype=torch.float32, device=device) + return x.view(L, T, H, D) + + +# ========================================================= +# Test 1: 基础功能(每 token 对应一个 page slot) +def test_basic_copy(): + L, T, H, D = 2, 8, 4, 16 + B = 1 # token_block_size + P = 4 # all_page_num + + gpu_kv = make_kv(L, T, H, D) + cpu_kv = torch.zeros((P, L, B, H, D), dtype=torch.float32, pin_memory=True) + + token_indexes = torch.tensor([1, 3, 5, 7], device="cuda") + page_indexes = torch.tensor([0, 1, 2, 3], device="cuda") + + offload_gpu_kv_to_cpu_all( + token_indexes, + gpu_kv, + None, + cpu_kv, + None, + page_indexes, + tp_index=0, + tp_world_size=1, + grid_num=1, + ) + + for i in range(len(token_indexes)): + token = token_indexes[i].item() + page = page_indexes[i].item() + + expected = gpu_kv[:, token, :, :].cpu() # (L,H,D) + actual = cpu_kv[page, :, 0, :, :] # (L,H,D) + + assert torch.allclose(expected, actual) + + +# ========================================================= +# 
Test 2: token乱序 +def test_random_tokens(): + L, T, H, D = 2, 16, 4, 8 + B = 1 + P = 6 + + gpu_kv = make_kv(L, T, H, D) + cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True) + + token_indexes = torch.tensor([10, 2, 7, 15, 0, 3], device="cuda") + page_indexes = torch.arange(6, device="cuda") + + offload_gpu_kv_to_cpu_all( + token_indexes, + gpu_kv, + None, + cpu_kv, + None, + page_indexes, + tp_index=0, + tp_world_size=1, + grid_num=1, + ) + + for i in range(6): + t = token_indexes[i].item() + p = page_indexes[i].item() + + assert torch.allclose( + cpu_kv[p, :, 0, :, :], + gpu_kv[:, t, :, :].cpu() + ) + + +# ========================================================= +# Test 3: 带 scale +def test_with_scale(): + L, T, H, D = 2, 8, 4, 16 + B = 1 + P = 3 + + gpu_kv = make_kv(L, T, H, D) + gpu_scale = torch.ones((L, T, H, D // 8), device="cuda") * 2.0 + + cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True) + cpu_scale = torch.zeros((P, L, B, H, D // 8), pin_memory=True) + + token_indexes = torch.tensor([1, 2, 3], device="cuda") + page_indexes = torch.tensor([0, 1, 2], device="cuda") + + offload_gpu_kv_to_cpu_all( + token_indexes, + gpu_kv, + gpu_scale, + cpu_kv, + cpu_scale, + page_indexes, + tp_index=0, + tp_world_size=1, + grid_num=1, + ) + + for i in range(3): + t = token_indexes[i].item() + p = page_indexes[i].item() + + # KV + assert torch.allclose( + cpu_kv[p, :, 0, :, :], + gpu_kv[:, t, :, :].cpu() + ) + + # scale + assert torch.allclose( + cpu_scale[p, :, 0, :], + gpu_scale[:, t, :].cpu() + ) + + +# ========================================================= +# Test 4: Tensor Parallel (按 head 切) +def test_tp_split(): + L, T, H, D = 2, 8, 4, 16 + B = 1 + P = 2 + tp_world_size = 2 + + gpu_kv = make_kv(L, T, H, D) + + cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True) + + token_indexes = torch.tensor([1, 2], device="cuda") + page_indexes = torch.tensor([0, 1], device="cuda") + + offload_gpu_kv_to_cpu_all( + token_indexes, gpu_kv, None, + cpu_kv, None, + page_indexes, + tp_index=0, + tp_world_size=tp_world_size, + grid_num=1, + ) + + offload_gpu_kv_to_cpu_all( + token_indexes, gpu_kv, None, + cpu_kv, None, + page_indexes, + tp_index=1, + tp_world_size=tp_world_size, + grid_num=1, + ) + + split = H // tp_world_size + + for i in range(2): + t = token_indexes[i].item() + p = page_indexes[i].item() + + assert torch.allclose( + cpu_kv[p, :, 0, :split, :], + gpu_kv[:, t, :split, :].cpu() + ) + + assert torch.allclose( + cpu_kv[p, :, 0, split:, :], + gpu_kv[:, t, split:, :].cpu() + ) + + +# ========================================================= +# Test 5: 空输入 +def test_empty(): + L, T, H, D = 2, 8, 4, 16 + B = 1 + P = 2 + + gpu_kv = make_kv(L, T, H, D) + cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True) + + token_indexes = torch.tensor([], dtype=torch.long, device="cuda") + page_indexes = torch.tensor([], dtype=torch.long, device="cuda") + + offload_gpu_kv_to_cpu_all( + token_indexes, + gpu_kv, + None, + cpu_kv, + None, + page_indexes, + tp_index=0, + tp_world_size=1, + grid_num=1, + ) + + assert torch.all(cpu_kv == 0) + +if __name__ == "__main__": + pytest.main() \ No newline at end of file From 142e3df42a2a595066727c8c6fce2297f5c7d245 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 25 Mar 2026 07:47:18 +0000 Subject: [PATCH 04/41] add naive x2i backend. 
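The naive backend consumes the prefix KV produced by PastKVCacheClient.get_kv_cache_for_x2i,
which returns a (num_layers, 2, kv_heads, num_tokens, head_dim) tensor, while the HF-style
modeling code added below keeps a per-layer cache in (batch, heads, seq, head_dim) layout.
A small glue function along the following lines could bridge the two; this is a sketch under
that assumption (kv_tensor_to_cache is a hypothetical helper, not part of the actual NEOX2I code):

    import torch
    from transformers.cache_utils import DynamicCache

    def kv_tensor_to_cache(kv: torch.Tensor) -> DynamicCache:
        # kv: (num_layers, 2, kv_heads, num_tokens, head_dim) from get_kv_cache_for_x2i
        cache = DynamicCache()
        for layer_idx in range(kv.shape[0]):
            k = kv[layer_idx, 0].unsqueeze(0)  # -> (1, kv_heads, num_tokens, head_dim)
            v = kv[layer_idx, 1].unsqueeze(0)
            cache.update(k, v, layer_idx)
        return cache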
--- lightllm/models/neo_chat_moe/model.py | 4 +- lightllm/server/api_cli.py | 6 + lightllm/server/api_start.py | 7 +- lightllm/server/core/objs/start_args_type.py | 3 + .../core/objs/token_chunck_hash_list.py | 10 + lightllm/server/core/objs/x2i_params.py | 58 +- lightllm/server/httpserver/manager.py | 38 +- lightllm/server/multimodal_params.py | 13 + lightllm/server/x2i_server/manager.py | 61 +- .../naive/configuration_neo_chat.py | 77 ++ .../x2i_server/naive/configuration_neo_vit.py | 52 + .../x2i_server/naive/modeling_fm_modules.py | 591 +++++++++ .../x2i_server/naive/modeling_neo_chat.py | 755 +++++++++++ .../x2i_server/naive/modeling_neo_vit.py | 235 ++++ .../server/x2i_server/naive/modeling_qwen3.py | 1117 +++++++++++++++++ .../server/x2i_server/past_kv_cache_client.py | 14 +- 16 files changed, 2995 insertions(+), 46 deletions(-) create mode 100644 lightllm/server/x2i_server/naive/configuration_neo_chat.py create mode 100644 lightllm/server/x2i_server/naive/configuration_neo_vit.py create mode 100644 lightllm/server/x2i_server/naive/modeling_fm_modules.py create mode 100644 lightllm/server/x2i_server/naive/modeling_neo_chat.py create mode 100644 lightllm/server/x2i_server/naive/modeling_neo_vit.py create mode 100644 lightllm/server/x2i_server/naive/modeling_qwen3.py diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py index d9f40d7feb..edabc06075 100644 --- a/lightllm/models/neo_chat_moe/model.py +++ b/lightllm/models/neo_chat_moe/model.py @@ -47,7 +47,6 @@ def __init__(self, tokenizer, model_cfg, **kwargs): def load_conversion_module(self, model_dir: str): import importlib - conversion_path = os.path.join(model_dir, "conversation.py") if not os.path.exists(conversion_path): return None @@ -155,8 +154,7 @@ def get_query_for_it2i(self, prompt: str): def get_query_for_t2i(self, prompt): query_condition = self._build_t2i_query( f"Please generate an image based on the following description: {prompt}", - thinking_content="\n\n\n\n", - ) + thinking_content="\n\n\n\n") query_uncondition = self._build_t2i_query(f"") return query_condition, query_uncondition diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 06d8c57320..4c4be78b3c 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -307,6 +307,12 @@ def make_argument_parser() -> argparse.ArgumentParser: action="store_true", help="Whether or not to allow to generate images (requird --enable_multimodal)." ) + parser.add_argument( + "--x2i_server_used_gpus", + type=int, + default=1, + help="Number of GPUs to use for x2i server (requird --enable_multimodal_x2i).", + ) parser.add_argument( "--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service." 
) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index dc347e4929..358ae425b4 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -353,8 +353,9 @@ def normal_or_p_d_start(args): ) if args.enable_multimodal_x2i: - from .x2i_server.manager import start_x2i_process - + from .x2i_server.manager import start_x2i_process, setup_devices + origin_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + setup_devices(args) process_manager.start_submodule_processes( start_funcs=[ start_x2i_process, @@ -363,6 +364,8 @@ def normal_or_p_d_start(args): (args,), ], ) + if origin_devices: + os.environ["CUDA_VISIBLE_DEVICES"] = origin_devices if args.enable_cpu_cache: from .multi_level_kv_cache.manager import start_multi_level_kv_cache_manager diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index b741847b4c..5cf13ef19a 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -162,4 +162,7 @@ class StartArgs: multi_level_kv_cache_port: int = field(default=None) x2i_port: int = field(default=None) http_server_port_for_x2i: int = field(default=None) + x2i_server_used_gpus: int = field(default=1) + + # multi_modal enable_multimodal_x2i: bool = field(default=False) diff --git a/lightllm/server/core/objs/token_chunck_hash_list.py b/lightllm/server/core/objs/token_chunck_hash_list.py index 479c6c17ff..b16f264936 100644 --- a/lightllm/server/core/objs/token_chunck_hash_list.py +++ b/lightllm/server/core/objs/token_chunck_hash_list.py @@ -94,8 +94,18 @@ class PastKVCachePageList(CpuCachePageList): _pack_ = 4 _fields_ = CpuCachePageList._fields_ +[ ("token_len", ctypes.c_int), # 对应的token数量 + ("img_tokens", ctypes.c_int), + ("img_len", ctypes.c_int) ] def __init__(self, token_len: int = 0): super().__init__() self.token_len = token_len + self.img_tokens = 0 + self.img_len = 0 + + def get_compressed_len(self): + return self.token_len - self.img_tokens + self.img_len + + def __repr__(self): + return f"(token_len={self.token_len}, img_tokens={self.img_tokens}, img_len={self.img_len})" \ No newline at end of file diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py index 66fcbe9afb..50b2595ec2 100644 --- a/lightllm/server/core/objs/x2i_params.py +++ b/lightllm/server/core/objs/x2i_params.py @@ -1,6 +1,6 @@ import ctypes from dataclasses import dataclass -from typing import Dict, List +from typing import Dict, List, Optional from enum import IntEnum from .token_chunck_hash_list import PastKVCachePageList @@ -8,6 +8,21 @@ class CfgNormType(IntEnum): NONE = 0 CFG_ZERO_STAR = 1 GLOBAL = 2 + TEXT_CHANNEL = 3 + CHANNEL = 4 + + def as_str(self) -> str: + mapping = { + CfgNormType.NONE: "none", + CfgNormType.CFG_ZERO_STAR: "cfg_zero_star", + CfgNormType.GLOBAL: "global", + CfgNormType.TEXT_CHANNEL: "text_channel", + CfgNormType.CHANNEL: "channel", + } + return mapping[self] + + def __repr__(self): + return self.as_str() class X2IParams(ctypes.Structure): @@ -28,9 +43,9 @@ class X2IParams(ctypes.Structure): ("request_id", ctypes.c_int64), ] - _width: int = 512 - _height: int = 512 - _steps: int = 30 + _width: int = 1024 + _height: int = 1024 + _steps: int = 50 _guidance_scale: float = 7.0 _image_guidance_scale: float = 7.0 _seed: int = 42 @@ -56,8 +71,11 @@ def _get(key, default): self.request_id = 0 def update(self, past_kv: PastKVCachePageList, meta: Dict): - past_kv.token_len = meta.get("prompt_tokens") - 
past_kv.fill(meta.get("kv_cache_pages")) + item: PastKVCacheItem = meta.get("kv_cache_item") + past_kv.token_len = item.token_len + past_kv.img_tokens = item.img_tokens + past_kv.img_len = item.img_len + past_kv.fill(item.page_indexes) self.total_prompt_tokens += past_kv.token_len def update_t2i(self, meta, meta_uncond): @@ -69,12 +87,36 @@ def update_it2i(self, meta, meta_text_uncond, meta_img_uncond): self.update(self.past_kvcache_text, meta_text_uncond) self.update(self.past_kvcache_img, meta_img_uncond) + def get_cfg_norm(self): + return CfgNormType(self.cfg_norm).as_str() + + def to_string(self): + parts = [] + for field_name, _ in self._fields_: + value = getattr(self, field_name) + parts.append(f"{field_name}={value}") + + return "X2IParams(" + ", ".join(parts) + ")" + + def __repr__(self): + return self.to_string() + @dataclass class X2IResponse: request_id: int - images: List[bytes] + images: Optional[List[bytes]] + @dataclass class X2ICacheRelease: - request_id: int \ No newline at end of file + request_id: int + + +@dataclass +class PastKVCacheItem: + req_id: int + token_len: int + img_tokens: int + img_len: int + page_indexes: List[int] \ No newline at end of file diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 1c943e4937..9502d630f0 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -25,7 +25,7 @@ from .async_queue import AsyncQueue from lightllm.server.core.objs import Req, FinishStatus, StartArgs from lightllm.server.core.objs import SamplingParams -from lightllm.server.core.objs.x2i_params import X2IParams, X2ICacheRelease, X2IResponse +from lightllm.server.core.objs.x2i_params import X2IParams, X2ICacheRelease, X2IResponse, PastKVCacheItem from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE from lightllm.server.core.objs.io_objs import GroupReqObjs from lightllm.server.core.objs.shm_req_manager import ShmReqManager @@ -369,8 +369,10 @@ async def generate( # allocate pages, may block if cache is full, but it won't cause deadlock # because the prefill process is designed to be sequential and the pages # will be released after prefill. + img_tokens = sum([img.token_num for img in multimodal_params.images]) + img_len = len(multimodal_params.images) kv_pages = self.past_kv_cache_client.allocate_pages( - req_obj.request_id, req_obj.input_len) + req_obj.request_id, req_obj.input_len, img_tokens, img_len) req_obj.past_kv_cache_page_indexes.fill(kv_pages) @@ -417,7 +419,7 @@ async def generate( yield sub_req_id, request_output, metadata, finish_status except Exception as e: - logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") + logger.error(f"group_request_id: {group_request_id} has exception {str(e)}", exc_info=e) # error need to release multimodel resources. 
# 对于还没有形成正式请求对象管理的多模态资源,需要单独自己释放 # 已经放入到 req_id_to_out_inf 中的请求对象,由统一的回收循环 @@ -441,35 +443,36 @@ async def generation_wrapper(prompt, sample, multimodal, request): async for sub_req_id, _, metadata, finish_status in self.generate( prompt, sample, multimodal, request ): - kv_cache_pages = self.past_kv_cache_client.get_pages_by_req_id(sub_req_id) - if kv_cache_pages is None: + kv_cache_item: PastKVCacheItem = self.past_kv_cache_client.get_pages_by_req_id(sub_req_id) + if kv_cache_item is None: raise Exception(f"kv_cache_pages is None for sub_req_id {sub_req_id}") - metadata["kv_cache_pages"] = kv_cache_pages + metadata["kv_cache_item"] = kv_cache_item metadata["request_id"] = sub_req_id metadata["finish_status"] = finish_status generate_req_ids.append(sub_req_id) return metadata try: - # 1. construct 3 or 2 images based on the multimodel_parmas + # 1. construct 3 or 2 generate based on the multimodel_parmas sample_params = SamplingParams() sample_params.init(self.tokenizer, **{"img_gen_prefill": True}) img_len = len(multimodal_params.images) if img_len > 0: # call it2i - prompt_condition = f"Please generate an image based on the following instruction: {prompt}" - prompt_text_uncondition = "Please generate an image based on the following instruction: "+ '\n' * img_len - prompt_img_uncondition = "Please generate an image based on the following instruction: " + re.sub(r"\n?", "", prompt) + # fix prompt, add tag if img_len greater than s in prompt + prompt = self.tokenizer.fix_prompt(prompt, img_len) + + prompt_condition, prompt_text_uncondition, prompt_img_uncondition = self.tokenizer.get_query_for_it2i(prompt) (con_gen, text_uncon_gen, img_uncon_gen) = await asyncio.gather(*[ generation_wrapper(prompt_condition, sample_params, multimodal_params, request), - generation_wrapper(prompt_text_uncondition, sample_params, multimodal_params, request), + generation_wrapper(prompt_text_uncondition, sample_params, multimodal_params.clone(), request), generation_wrapper(prompt_img_uncondition, sample_params, MultimodalParams(), request)]) generation_params.update_it2i(con_gen, text_uncon_gen, img_uncon_gen) else: # call t2i - prompt_condition = f"Please generate an image based on the following caption: {prompt}" - prompt_uncondition = f"Please generate an image based on the following caption: " + prompt_condition, prompt_uncondition = self.tokenizer.get_query_for_t2i(prompt) + logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}") (con_gen, uncon_gen) = await asyncio.gather(*[ generation_wrapper(prompt_condition, sample_params, multimodal_params, request), generation_wrapper(prompt_uncondition, sample_params, multimodal_params, request)]) @@ -494,7 +497,7 @@ async def generation_wrapper(prompt, sample, multimodal, request): except Exception as e: logger.error(str(e)) - pass + return [] finally: for req_id in generate_req_ids: @@ -844,13 +847,18 @@ async def loop_for_x2i(self): elif isinstance(recv_obj, X2IResponse): status = self.req_id_to_x2i_reqs[recv_obj.request_id] + if recv_obj.images is None: + for req_id in status.req_ids: + self.past_kv_cache_client.free_pages_by_req_id(req_id) + status.response = recv_obj status.event.set() except asyncio.TimeoutError: pass + except Exception as e: - logger.error(e) + logger.error(e, exc_info=e) async def handle_loop(self): diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 09a07455b3..999dfc6859 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ 
-184,3 +184,13 @@ def to_origin_dict(self): ret["images"] = [i.to_origin_dict() for i in self.images] ret["audios"] = [a.to_origin_dict() for a in self.audios] return ret + + def free(self): + for image in self.images: + image.free() + for audio in self.audios: + audio.free() + return + + def clone(self): + return MultimodalParams(**self.to_origin_dict()) diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index 3dc82be897..11e53b5d67 100644 --- a/lightllm/server/x2i_server/manager.py +++ b/lightllm/server/x2i_server/manager.py @@ -6,6 +6,7 @@ import setproctitle import pickle import torch +import os from typing import List from lightllm.server.core.objs import StartArgs @@ -15,10 +16,12 @@ from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease +from lightllm.utils.dist_utils import set_current_device_id from .past_kv_cache_client import PastKVCacheClient logger = init_logger(__name__) + ''' manage a generation service, 1. start x2v pipelines 2. receive generation request from http_server. 3. call llm gen to obtain past key values 4. call x2v to generate images and pass the key values to it 5. return the generated images. @@ -44,10 +47,6 @@ def __init__( self.waiting_reqs: List[X2IParams] = [] - from lightllm.utils.dist_utils import set_current_device_id - - set_current_device_id(torch.cuda.current_device()) - self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True) async def wait_to_model_ready(self): @@ -61,8 +60,21 @@ async def wait_to_model_ready(self): # config_json = self.args.x2v_gen_model_config, # ) + from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I + + self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device()) + pass + async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams): + images = self.naive_x2i.t2i(past_kv_cache, past_kv_cache_text, param) + return images + + async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams): + images = self.naive_x2i.it2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img, param) + return images + + async def loop_for_fwd(self): while True: try: @@ -81,8 +93,6 @@ async def loop_for_fwd(self): ) is_t2i = x2i_param.past_kvcache_img.is_empty() - logger.info(f"past kv cache shape: {past_kv_cache.shape}, past_kv_cache_text shape: {past_kv_cache_text.shape}") - past_kv_cache_img = None if not is_t2i: # t2i past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i( @@ -94,13 +104,24 @@ async def loop_for_fwd(self): X2ICacheRelease(request_id=x2i_param.request_id), protocol=pickle.HIGHEST_PROTOCOL) - # call generate images + images = [] + logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with x2i_param: {x2i_param}") + if is_t2i: + images = await self.t2i_generate(past_kv_cache, past_kv_cache_text, x2i_param) + else: + images = await self.it2i_generate(past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param) + self.send_to_httpserver.send_pyobj(X2IResponse( request_id=x2i_param.request_id, images=images), protocol=pickle.HIGHEST_PROTOCOL) except Exception as e: + self.send_to_httpserver.send_pyobj(X2IResponse( + request_id=x2i_param.request_id, + images=None), + protocol=pickle.HIGHEST_PROTOCOL) + logger.error(e) @@ -115,11 +136,35 @@ async def loop_for_netio_req(self): await asyncio.sleep(0.01) + def clean_up(self): + pass + +def setup_devices(args: StartArgs): + devices =
os.environ.get("CUDA_VISIBLE_DEVICES", "").strip() + logger.info(f"current devices: {devices} {torch.cuda.device_count()}") + if not devices: + devices = list(range(torch.cuda.device_count())) + else: + devices = [int(x.strip()) for x in devices.split(",") if x.strip()] + + llm_need_gpus = args.tp * args.dp + x2i_need_gpus = args.x2i_server_used_gpus + if len(devices) < llm_need_gpus + x2i_need_gpus: + raise ValueError(f"devices {devices} not enough, need {llm_need_gpus} and {x2i_need_gpus}") + + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices[ + llm_need_gpus:llm_need_gpus + x2i_need_gpus])) + + logger.info(f"setup devices for x2i server: {os.environ['CUDA_VISIBLE_DEVICES']}, " + f"{torch.cuda.device_count()} {torch.cuda.current_device()}") + + def start_x2i_process(args, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::x2i_server") start_parent_check_thread() + set_current_device_id(torch.cuda.current_device()) try: x2iserver = X2IManager(args=args,) asyncio.run(x2iserver.wait_to_model_ready()) diff --git a/lightllm/server/x2i_server/naive/configuration_neo_chat.py b/lightllm/server/x2i_server/naive/configuration_neo_chat.py new file mode 100644 index 0000000000..44171917a5 --- /dev/null +++ b/lightllm/server/x2i_server/naive/configuration_neo_chat.py @@ -0,0 +1,77 @@ +import copy + +from transformers import Qwen3Config +from transformers.utils import logging +from transformers.configuration_utils import PretrainedConfig + +from .configuration_neo_vit import NEOVisionConfig + + +logger = logging.get_logger(__name__) + + +class NEOLLMConfig(Qwen3Config): + def __init__(self, rope_theta_hw=10000.0, max_position_embeddings_hw=10000, **kwargs): + super().__init__(**kwargs) + self.rope_theta_hw = rope_theta_hw + self.max_position_embeddings_hw = max_position_embeddings_hw + + +class NEOChatConfig(PretrainedConfig): + model_type = 'neo_chat' + is_composition = True + + def __init__( + self, + vision_config=None, + llm_config=None, + use_backbone_lora=0, + use_llm_lora=0, + downsample_ratio=0.5, + template=None, + **kwargs, + ): + super().__init__(**kwargs) + + if vision_config is None: + vision_config = {'architectures': ['NEOVisionModel']} + logger.info('vision_config is None. Initializing the NEOVisionConfig with default values.') + + if llm_config is None: + llm_config = {'architectures': ['Qwen3ForCausalLM']} + logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).') + assert 'architectures' in llm_config, "Should specify architecture in llm_config" + + if isinstance(vision_config, dict): + self.vision_config = NEOVisionConfig(**vision_config) + else: + self.vision_config = vision_config + + if isinstance(llm_config, dict): + self.llm_config = NEOLLMConfig(**llm_config) + else: + self.llm_config = llm_config + + self.use_backbone_lora = use_backbone_lora + self.use_llm_lora = use_llm_lora + self.downsample_ratio = downsample_ratio + self.template = template + self.tie_word_embeddings = self.llm_config.tie_word_embeddings + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. 
+ + Returns: + `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, + """ + output = copy.deepcopy(self.__dict__) + output['vision_config'] = self.vision_config.to_dict() + output['llm_config'] = self.llm_config.to_dict() + output['model_type'] = self.__class__.model_type + output['use_backbone_lora'] = self.use_backbone_lora + output['use_llm_lora'] = self.use_llm_lora + output['downsample_ratio'] = self.downsample_ratio + output['template'] = self.template + + return output diff --git a/lightllm/server/x2i_server/naive/configuration_neo_vit.py b/lightllm/server/x2i_server/naive/configuration_neo_vit.py new file mode 100644 index 0000000000..02837fea41 --- /dev/null +++ b/lightllm/server/x2i_server/naive/configuration_neo_vit.py @@ -0,0 +1,52 @@ +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class NEOVisionConfig(PretrainedConfig): + + model_type = 'neo_vision' + + def __init__( + self, + num_channels=3, + patch_size=16, + hidden_size=1024, + llm_hidden_size=2048, + downsample_ratio=0.5, + rope_theta_vision=10000.0, + max_position_embeddings_vision=10000, + min_pixels=65536, + max_pixels=4194304, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.llm_hidden_size = llm_hidden_size, + self.downsample_ratio = downsample_ratio, + self.rope_theta_vision = rope_theta_vision + self.max_position_embeddings_vision = max_position_embeddings_vision + self.num_channels = num_channels + self.patch_size = patch_size + self.min_pixels = min_pixels + self.max_pixels = max_pixels + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig': + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + if 'vision_config' in config_dict: + config_dict = config_dict['vision_config'] + + if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.' + ) + + return cls.from_dict(config_dict, **kwargs) \ No newline at end of file diff --git a/lightllm/server/x2i_server/naive/modeling_fm_modules.py b/lightllm/server/x2i_server/naive/modeling_fm_modules.py new file mode 100644 index 0000000000..56654b2580 --- /dev/null +++ b/lightllm/server/x2i_server/naive/modeling_fm_modules.py @@ -0,0 +1,591 @@ +import numpy as np +import torch +import torch.nn as nn +import math +from functools import lru_cache + +from torch.utils.checkpoint import checkpoint +def modulate(x, shift, scale=None): + if shift is None: + return x * (1 + scale) + return x * (1 + scale) + shift + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + return output * self.weight + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. 
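+
+    A quick illustrative check of the shape contract (not part of the model
+    code): for an even dim the sinusoidal embedding concatenates a cosine half
+    and a sine half, so
+
+        emb = TimestepEmbedder.timestep_embedding(torch.tensor([0.0, 0.5, 1.0]), dim=256)
+        assert emb.shape == (3, 256)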
+ """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=t.device + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) + return t_emb + +class ResBlock(nn.Module): + + def __init__(self, channels, mlp_ratio=1.0): + super().__init__() + self.channels = channels + self.intermediate_size = int(channels * mlp_ratio) + + self.in_ln = nn.LayerNorm(self.channels, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(self.channels, self.intermediate_size), + nn.SiLU(), + nn.Linear(self.intermediate_size, self.channels), + ) + + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True)) + + def forward(self, x, y): + shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) + h = modulate(self.in_ln(x), shift_mlp, scale_mlp) + h = self.mlp(h) + return x + gate_mlp * h + +# class FinalLayer(nn.Module): + +# def __init__(self, model_channels, out_channels): +# super().__init__() +# self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) +# self.linear = nn.Linear(model_channels, out_channels, bias=True) +# self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True)) + +# def forward(self, x, c): +# shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) +# x = modulate(self.norm_final(x), shift, scale) +# x = self.linear(x) +# return x + +# class SimpleMLPAdaLN(nn.Module): + +# def __init__(self, input_dim, out_dim, dim=1536, layers=12, mlp_ratio=1.0): +# super().__init__() +# self.input_dim = input_dim +# self.out_dim = out_dim +# self.dim = dim +# self.layers = layers +# self.mlp_ratio = mlp_ratio + +# self.time_embed = TimestepEmbedder(dim) +# self.input_proj = nn.Linear(input_dim, dim) + +# res_blocks = [] +# for _ in range(layers): +# res_blocks.append(ResBlock(dim, mlp_ratio)) +# self.res_blocks = nn.ModuleList(res_blocks) + +# self.final_layer = FinalLayer(dim, out_dim) + +# self.grad_checkpointing = False + +# self.initialize_weights() + +# def initialize_weights(self): +# def _basic_init(module): +# if isinstance(module, nn.Linear): +# torch.nn.init.xavier_uniform_(module.weight) +# if module.bias is not None: +# nn.init.constant_(module.bias, 0) + +# self.apply(_basic_init) + +# # Initialize timestep embedding MLP +# nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) +# 
nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) + +# # Zero-out adaLN modulation layers +# for block in self.res_blocks: +# nn.init.constant_(block.adaLN_modulation[-1].weight, 0) +# nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + +# # Zero-out output layers +# nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) +# nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) +# nn.init.constant_(self.final_layer.linear.weight, 0) +# nn.init.constant_(self.final_layer.linear.bias, 0) + +# def forward(self, x, t): +# """ +# x.shape = (bsz, input_dim) +# t.shape = (bsz,) +# """ + +# x = self.input_proj(x) +# t = self.time_embed(t) + +# y = t + +# for block in self.res_blocks: +# if self.grad_checkpointing and self.training: +# x = checkpoint(block, x, y, use_reentrant=True) +# else: +# x = block(x, y) + +# return self.final_layer(x, y) + +class FlowMatchingHead(nn.Module): + + def __init__(self, input_dim, out_dim, dim=1536, layers=12, mlp_ratio=1.0): + super(FlowMatchingHead, self).__init__() + self.net = SimpleMLPAdaLN(input_dim=input_dim, out_dim=out_dim, dim=dim, layers=layers, mlp_ratio=mlp_ratio) + + @property + def dtype(self): + return self.net.input_proj.weight.dtype + + @property + def device(self): + return self.net.input_proj.weight.device + + def forward(self, x, t): + x = self.net(x, t) + return x + + +def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): + # assert H * H == end + # flat_patch_pos = torch.linspace(-1, 1, end) # N = end + x_pos = torch.linspace(0, scale, width) + y_pos = torch.linspace(0, scale, height) + y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") + y_pos = y_pos.reshape(-1) + x_pos = x_pos.reshape(-1) + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) + y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height*width, -1) + return freqs_cis + +class NerfEmbedder(nn.Module): + def __init__(self, in_channels, hidden_size_input, max_freqs): + super().__init__() + self.max_freqs = max_freqs + self.hidden_size_input = hidden_size_input + self.embedder = nn.Sequential( + nn.Linear(in_channels+max_freqs**2, hidden_size_input, bias=True), + ) + + @lru_cache + def fetch_pos(self, patch_size, device, dtype): + pos = precompute_freqs_cis_2d(self.max_freqs ** 2 * 2, patch_size, patch_size).real + pos = pos[None, :, :].to(device=device, dtype=dtype) + return pos + + + def forward(self, inputs): + B, P2, C = inputs.shape + patch_size = int(P2 ** 0.5) + device = inputs.device + dtype = inputs.dtype + dct = self.fetch_pos(patch_size, device, dtype) + dct = dct.repeat(B, 1, 1) + inputs = torch.cat([inputs, dct], dim=-1) + inputs = self.embedder(inputs) + return inputs + +class SimpleMLPAdaLN(nn.Module): + """ + The MLP for Diffusion Loss. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param z_channels: channels in the condition. + :param num_res_blocks: number of residual blocks per downsample. 
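+
+    A plausible shape walk-through (an assumption about the intended call
+    pattern, not a documented contract): x carries one patch_size**2 group of
+    tokens per conditioning vector c from the AR transformer.
+
+        head = SimpleMLPAdaLN(in_channels=48, model_channels=1536, out_channels=48,
+                              z_channels=4096, num_res_blocks=2, patch_size=2)
+        x = torch.randn(8, 2 * 2, 48)   # (B, patch_size**2, in_channels)
+        c = torch.randn(8, 4096)        # (B, z_channels)
+        out = head(x, c)                # -> (8, 4, 48)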
+ """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + z_channels, + num_res_blocks, + patch_size, + grad_checkpointing=False + ): + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.grad_checkpointing = grad_checkpointing + self.patch_size = patch_size + + self.cond_embed = nn.Linear(z_channels, patch_size**2*model_channels) + + self.input_proj = nn.Linear(in_channels, model_channels) + + res_blocks = [] + for i in range(num_res_blocks): + res_blocks.append(ResBlock( + model_channels, + )) + + self.res_blocks = nn.ModuleList(res_blocks) + self.final_layer = FinalLayer(model_channels, out_channels) + + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Zero-out adaLN modulation layers + for block in self.res_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, c): + """ + Apply the model to an input batch. + :param x: an [N x C] Tensor of inputs. + :param t: a 1-D batch of timesteps. + :param c: conditioning from AR transformer. + :return: an [N x C] Tensor of outputs. + """ + x = self.input_proj(x) + c = self.cond_embed(c) + + y = c.reshape(-1, self.patch_size**2, self.model_channels) + + for block in self.res_blocks: + x = block(x, y) + + return self.final_layer(x) + + +class FinalLayer(nn.Module): + """ + The final layer adopted from DiT. 
+ """ + def __init__(self, model_channels, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(model_channels, out_channels, bias=True) + + def forward(self, x): + x = self.norm_final(x) + x = self.linear(x) + return x + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) / pe_interpolation + grid_w = np.arange(grid_size, dtype=np.float32) / pe_interpolation + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model_path, pe_key: str = "gen_pos_embed", new_len: int = 4096): + state_dict = torch.load(model_path, map_location="cpu") + + pos_embed_1d = state_dict[pe_key] + _, ori_len, embed_dim = pos_embed_1d.shape + + ori_size = int(ori_len**0.5) + new_size = int(new_len**0.5) + + if ori_size != new_size: + logger.info("Position interpolate from %dx%d to %dx%d" % (ori_size, ori_size, new_size, new_size)) + pos_embed_2d = pos_embed_1d.reshape(-1, ori_size, ori_size, embed_dim).permute(0, 3, 1, 2) + pos_embed_2d = torch.nn.functional.interpolate( + pos_embed_2d, size=(new_size, new_size), mode="bicubic", align_corners=False + ) + pos_embed_1d = pos_embed_2d.permute(0, 2, 3, 1).flatten(1, 2) + state_dict[pe_key] = pos_embed_1d + + torch.save(state_dict, model_path) + +class PositionEmbedding(nn.Module): + def __init__(self, max_num_patch_per_side, hidden_size): + super().__init__() + self.max_num_patch_per_side = max_num_patch_per_side + self.hidden_size 
= hidden_size + self.pos_embed = nn.Parameter( + torch.zeros(max_num_patch_per_side ** 2, hidden_size), + requires_grad=False + ) + self._init_weights() + + def _init_weights(self): + # Initialize (and freeze) pos_embed by sin-cos embedding: + pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float()) + + def forward(self, position_ids): + return self.pos_embed[position_ids] + + +class ResidualConvBlock(nn.Module): + def __init__(self, channels: int): + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(channels, channels, kernel_size=3, padding=1), + nn.SiLU(), + nn.Conv2d(channels, channels, kernel_size=3, padding=1), + ) + nn.init.zeros_(self.block[2].weight) + nn.init.zeros_(self.block[2].bias) + + def forward(self, x): + return x + self.block(x) + + +class PostConvSmoother(nn.Module): + def __init__(self, in_channels=3, hidden_channels=64, num_blocks=3): + super().__init__() + self.in_proj = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1) + self.blocks = nn.Sequential(*[ResidualConvBlock(hidden_channels) for _ in range(num_blocks)]) + self.out_proj = nn.Conv2d(hidden_channels, in_channels, kernel_size=1) + + nn.init.zeros_(self.out_proj.weight) + nn.init.zeros_(self.out_proj.bias) + + def forward(self, x): + h = self.in_proj(x) + h = self.blocks(h) + return x + self.out_proj(h) + + +class ProgressiveConvDecoder(nn.Module): + def __init__(self, hidden_dim=4096, out_channels=3): + super().__init__() + + # self.proj = nn.Linear(hidden_dim, 1024) + # self.act = nn.SiLU() + + self.up_blocks = nn.ModuleList([ + nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv2d(hidden_dim, 512, kernel_size=3, padding=1), + nn.GroupNorm(32, 512), + nn.SiLU() + ), + nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv2d(512, 256, kernel_size=3, padding=1), + nn.GroupNorm(32, 256), + nn.SiLU() + ), + nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv2d(256, 64, kernel_size=3, padding=1), + nn.GroupNorm(32, 64), + nn.SiLU() + ), + nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv2d(64, 32, kernel_size=3, padding=1), + nn.GroupNorm(16, 32), + nn.SiLU() + ), + nn.Sequential( + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv2d(32, 16, kernel_size=3, padding=1), + nn.SiLU() + ) + ]) + + self.out_conv = nn.Conv2d(16, out_channels, kernel_size=3, padding=1) + + def forward(self, x_2d): + # B, C, H, W = x_2d.shape + # x = x_2d.permute(0, 2, 3, 1).contiguous() # (B, H, W, C) + # x = self.proj(x) + # x = self.act(x) + # x = x.permute(0, 3, 1, 2).contiguous() # (B, 512, H, W) + x = x_2d + for block in self.up_blocks: + x = block(x) + + out = self.out_conv(x) + return out + + +class PatchDecoder_postps(nn.Module): + def __init__(self): + super().__init__() + # layer 1: H/32 -> H/8 (4x upscale) + + self.conv1 = nn.Conv2d(4096, 4096, kernel_size=3, padding=1) + self.ps1 = nn.PixelShuffle(4) + self.act1 = nn.GELU() + + # layer 2: H/8 -> H (8x upscale) + self.conv2 = nn.Conv2d(256, 192, kernel_size=3, padding=1) + self.ps2 = nn.PixelShuffle(8) + + def forward(self, x): + # x shape: [B, 4096, H/32, W/32] + x = self.ps1(self.act1(self.conv1(x))) # -> [B, 256, H/8, W/8] + x = self.ps2(self.conv2(x)) # -> [B, 3, H, W] + return x + + +class PatchDecoder_preps(nn.Module): + def __init__(self): + super().__init__() + # layer 1: H/32 -> H/16 (2x upscale) + self.ps1 = nn.PixelShuffle(2) + self.conv1 = nn.Conv2d(1024, 1024, 
kernel_size=3, padding=1) + self.act1 = nn.GELU() + + # layer 2: H/16 -> H/8 (2x upscale) + self.ps2 = nn.PixelShuffle(2) + self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.act2 = nn.GELU() + + # layer 3: H/8 -> H (8x upscale) + self.ps3 = nn.PixelShuffle(8) + self.conv3 = nn.Conv2d(4, 3, kernel_size=3, padding=1) + + def forward(self, x): + # x shape: [B, 4096, H/32, W/32] + x = self.act1(self.conv1(self.ps1((x)))) # -> [B, 256, H/16, W/16] + x = self.act2(self.conv2(self.ps2((x)))) # -> [B, 256, H/8, W/8] + x = self.conv3(self.ps3((x))) # -> [B, 3, H, W] + return x + +class PatchDecoder_preps1(nn.Module): + def __init__(self): + super().__init__() + # layer 1: H/32 -> H/16 (2x upscale) + self.ps1 = nn.PixelShuffle(2) + self.conv1 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1) + self.act1 = nn.GELU() + + # layer 2: H/16 -> H/8 (2x upscale) + self.ps2 = nn.PixelShuffle(2) + self.conv2 = nn.Conv2d(256, 192, kernel_size=3, padding=1) + + # layer 3: H/8 -> H (8x upscale) + self.ps3 = nn.PixelShuffle(8) + + def forward(self, x): + # x shape: [B, 4096, H/32, W/32] + x = self.act1(self.conv1(self.ps1((x)))) # -> [B, 256, H/16, W/16] + x = self.ps3(self.conv2(self.ps2((x)))) # -> [B, 256, H/8, W/8] + return x + +class ConvDecoder(nn.Module): + def __init__(self, input_dim=4096, hidden_dim=1024): + super().__init__() + # layer 1: H/32 -> H/16 (2x upscale) + self.ps1 = nn.PixelShuffle(2) + self.conv1 = nn.Conv2d(input_dim // 4, hidden_dim, kernel_size=3, padding=1) + self.act1 = nn.GELU() + + # layer 2: H/16 -> H/8 (2x upscale) + self.ps2 = nn.PixelShuffle(2) + self.conv2 = nn.Conv2d(hidden_dim // 4, 192, kernel_size=3, padding=1) + + # layer 3: H/8 -> H (8x upscale) + self.ps3 = nn.PixelShuffle(8) + + def forward(self, x): + x = self.act1(self.conv1(self.ps1((x)))) + x = self.ps3(self.conv2(self.ps2((x)))) + return x diff --git a/lightllm/server/x2i_server/naive/modeling_neo_chat.py b/lightllm/server/x2i_server/naive/modeling_neo_chat.py new file mode 100644 index 0000000000..890d761391 --- /dev/null +++ b/lightllm/server/x2i_server/naive/modeling_neo_chat.py @@ -0,0 +1,755 @@ +from typing import List, Optional, Tuple, Union +import math +import torch.utils.checkpoint +from torch import nn +import transformers +import numpy as np +import base64 +from PIL import Image +from torch.nn import CrossEntropyLoss +from transformers import GenerationConfig +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.cache_utils import Cache, DynamicCache +from lightllm.server.core.objs.x2i_params import X2IParams +import torchvision.io as io + +from .configuration_neo_chat import NEOChatConfig +from .modeling_neo_vit import NEOVisionModel +from .modeling_qwen3 import Qwen3ForCausalLM, create_block_causal_mask +from .modeling_fm_modules import PositionEmbedding, TimestepEmbedder, FlowMatchingHead, RMSNorm, NerfEmbedder, SimpleMLPAdaLN, ConvDecoder + + +logger = logging.get_logger(__name__) + + +def version_cmp(v1, v2, op='eq'): + import operator + + from packaging import version + op_func = getattr(operator, op) + return op_func(version.parse(v1), version.parse(v2)) + +def prepare_flash_kv_cache( + past_key_values, + current_len: int, + batch_size: int, +): + """ + Convert prefix cache from [B, H, S, D] to flash-attn friendly [B, S, H, D], + and preallocate full KV buffer for [prefix + current]. + + This is done once before denoising loop. 
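+
+    Illustrative shapes, assuming batch B, prefix length S_p and current block
+    length S_c (symbols used only in this note):
+
+        layer.keys            : (B, H, S_p, D)        # cached prefix as stored by the LLM
+        layer.flash_k_cache   : (B, S_p + S_c, H, D)  # preallocated, prefix copied in
+        layer.flash_prefix_len = S_p
+        layer.flash_total_len  = S_p + S_c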
+ """ + if past_key_values is None: + return + + for layer in past_key_values.layers: + past_k = layer.keys + past_v = layer.values + + if past_k is None or past_v is None: + layer.flash_prefix_len = 0 + layer.flash_total_len = current_len + layer.flash_k_cache = None + layer.flash_v_cache = None + continue + + # original cache layout assumed: [B, H, S, D] + past_k_flash = past_k.transpose(1, 2).contiguous() # [B, S, H, D] + past_v_flash = past_v.transpose(1, 2).contiguous() # [B, S, H, D] + + prefix_len = past_k_flash.shape[1] + total_len = prefix_len + current_len + + k_cache = torch.empty( + (batch_size, total_len, past_k_flash.shape[2], past_k_flash.shape[3]), + device=past_k_flash.device, + dtype=past_k_flash.dtype, + ) + v_cache = torch.empty( + (batch_size, total_len, past_v_flash.shape[2], past_v_flash.shape[3]), + device=past_v_flash.device, + dtype=past_v_flash.dtype, + ) + + k_cache[:, :prefix_len].copy_(past_k_flash) + v_cache[:, :prefix_len].copy_(past_v_flash) + + layer.flash_prefix_len = prefix_len + layer.flash_total_len = total_len + layer.flash_k_cache = k_cache + layer.flash_v_cache = v_cache + +def clear_flash_kv_cache(past_key_values): + if past_key_values is None: + return + for layer in past_key_values.layers: + if hasattr(layer, "flash_prefix_len"): + delattr(layer, "flash_prefix_len") + if hasattr(layer, "flash_total_len"): + delattr(layer, "flash_total_len") + if hasattr(layer, "flash_k_cache"): + delattr(layer, "flash_k_cache") + if hasattr(layer, "flash_v_cache"): + delattr(layer, "flash_v_cache") + + +@torch.cuda.amp.autocast(dtype=torch.float32) +def optimized_scale(positive_flat, negative_flat): + + # Calculate dot production + dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) + + # Squared norm of uncondition + squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 + + # st_star = v_cond^T * v_uncond / ||v_uncond||^2 + st_star = dot_product / squared_norm + + return st_star + + +def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): + """ + Compute patch coordinates (x, y) + + Args: + grid_hw: (B, 2) tensor representing (H, W) per image + """ + device = grid_hw.device + B = grid_hw.shape[0] + + # Get the number of patches per image + H = grid_hw[:, 0] + W = grid_hw[:, 1] + N = H * W + N_total = N.sum() + + # Create the batch index for each patch (B x patch count) + patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,) + + # Generate intra-image patch index (row-major order) + patch_id_within_image = torch.arange(N_total, device=device) + patch_id_within_image = patch_id_within_image - torch.cumsum( + torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0 + )[patch_to_sample] + + # Get H/W for each patch according to its image + W_per_patch = W[patch_to_sample] + abs_x = patch_id_within_image % W_per_patch + abs_y = patch_id_within_image // W_per_patch + + return abs_x, abs_y + + +class NEOChatModel(PreTrainedModel): + config_class = NEOChatConfig + main_input_name = 'pixel_values' + base_model_prefix = 'language_model' + _supports_flash_attn_2 = True + supports_gradient_checkpointing = True + _no_split_modules = [ + "NEOVisionModel", + "Qwen3DecoderLayer", + ] + + # support transformers 4.51.+ + _tp_plan = '' + + def __init__(self, config: NEOChatConfig, vision_model=None, language_model=None, use_flash_attn=True): + super().__init__(config) + + assert version_cmp(transformers.__version__, '4.37.0', 'ge') + patch_size = config.vision_config.patch_size + 
self.patch_size = patch_size + self.template = config.template + self.downsample_ratio = config.downsample_ratio + config.llm_config._attn_implementation = 'eager' + + if vision_model is not None: + self.vision_model = vision_model + else: + self.vision_model = NEOVisionModel(config.vision_config) + vision_model_mot_gen = NEOVisionModel(config.vision_config) + if language_model is not None: + self.language_model = language_model + else: + self.language_model = Qwen3ForCausalLM(config.llm_config) + + merge_size = int(1 / self.downsample_ratio) + output_dim = 3*(patch_size*merge_size)**2 + llm_hidden_size = self.config.llm_config.hidden_size + self.use_deep_fm_head = self.config.fm_head_layers > 2 + self.use_pixel_head = self.config.use_pixel_head + + if self.use_deep_fm_head: + fm_head = FlowMatchingHead(llm_hidden_size, output_dim, dim=self.config.fm_head_dim, layers=self.config.fm_head_layers, mlp_ratio=self.config.fm_head_mlp_ratio) + else: + fm_head = nn.Sequential( + nn.Linear(llm_hidden_size, 4096, bias=True), + nn.GELU(), + nn.Linear(4096, output_dim, bias=True), + ) + + timestep_embedder = TimestepEmbedder(llm_hidden_size) + self.fm_modules = nn.ModuleDict( + { + "vision_model_mot_gen": vision_model_mot_gen, + "timestep_embedder": timestep_embedder, + "fm_head": fm_head + } + ) + if self.use_pixel_head: + self.fm_modules["fm_head"] = ConvDecoder(llm_hidden_size) + + + self.concat_time_token_num = config.concat_time_token_num + self.time_token_id = 151682 + self.noise_scale = config.noise_scale + self.noise_scale_mode = config.noise_scale_mode + self.noise_scale_base_image_seq_len = config.noise_scale_base_image_seq_len + + self.add_noise_scale_embedding = config.add_noise_scale_embedding + self.noise_scale_max_value = config.noise_scale_max_value + self.time_schedule = config.time_schedule + self.time_shift_type = config.time_shift_type + self.base_shift = config.base_shift + self.max_shift = config.max_shift + self.base_image_seq_len = config.base_image_seq_len + self.max_image_seq_len = config.max_image_seq_len + + if self.add_noise_scale_embedding: + noise_scale_embedder = TimestepEmbedder(llm_hidden_size) + self.fm_modules['noise_scale_embedder'] = noise_scale_embedder + + + + self.img_context_token_id = None + self.img_start_token_id = 151670 + # self.conv_template = get_conv_template(self.template) + # self.system_message = self.conv_template.system_message + + + def extract_feature(self, pixel_values, gen_model=False, grid_hw=None): + if gen_model: + return self.fm_modules['vision_model_mot_gen'](pixel_values=pixel_values, + output_hidden_states=False, + return_dict=True, + grid_hw=grid_hw).last_hidden_state + else: + return self.vision_model(pixel_values=pixel_values, + output_hidden_states=False, + return_dict=True, + grid_hw=grid_hw).last_hidden_state + + def patchify(self, images, patch_size, channel_first=False): + """ + images: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + h, w = images.shape[2] // patch_size, images.shape[3] // patch_size + x = images.reshape(shape=(images.shape[0], 3, h, patch_size, w, patch_size)) + + if channel_first: + x = torch.einsum('nchpwq->nhwcpq', x) + else: + x = torch.einsum('nchpwq->nhwpqc', x) + + x = x.reshape(shape=(images.shape[0], h * w, patch_size**2 * 3)) + return x + + def unpatchify(sle, x, patch_size, h=None, w=None): + """ + x: (N, L, patch_size**2 *3) + images: (N, 3, H, W) + """ + if h is None or w is None: + h = w = int(x.shape[1]**.5) + else: + h = h // patch_size + w = w // patch_size + x = x.reshape(shape=(x.shape[0], 
h, w, patch_size, patch_size, 3)) + x = torch.einsum('nhwpqc->nchpwq', x) + images = x.reshape(shape=(x.shape[0], 3, h * patch_size, w * patch_size)) + return images + + def _euler_step(self, v_pred, z, t, t_next): + z_next = z + (t_next - t) * v_pred + return z_next + + def _calculate_dynamic_mu(self, image_seq_len: int) -> float: + denom = self.max_image_seq_len - self.base_image_seq_len + if denom == 0: + return float(self.base_shift) + m = (self.max_shift - self.base_shift) / denom + b = self.base_shift - m * self.base_image_seq_len + return float(image_seq_len) * m + b + + def _apply_time_schedule(self, t: torch.Tensor, image_seq_len: int, timestep_shift: float) -> torch.Tensor: + self.time_schedule = "standard" + sigma = 1 - t + if timestep_shift != 1: + self.time_schedule = "standard" + if self.time_schedule == "standard": + shift = timestep_shift + sigma = shift * sigma / (1 + (shift - 1) * sigma) + elif self.time_schedule == "dynamic": + mu = self._calculate_dynamic_mu(image_seq_len) + mu_t = t.new_tensor(mu) + if self.time_shift_type == "exponential": + shift = torch.exp(mu_t) + sigma = shift * sigma / (1 + (shift - 1) * sigma) + elif self.time_shift_type == "linear": + sigma = mu_t / (mu_t + (1 / sigma - 1)) + else: + raise ValueError(f"Unsupported time_shift_type: {self.time_shift_type}") + else: + raise ValueError(f"Unsupported time_schedule: {self.time_schedule}") + return 1 - sigma + + def _build_t2i_image_indexes(self, token_h, token_w, text_len, device): + t_image = torch.full((token_h * token_w,), text_len, dtype=torch.long, device=device) + idx = torch.arange(token_h * token_w, device=device, dtype=torch.long) + h_image = idx // token_w + w_image = idx % token_w + return torch.stack([t_image, h_image, w_image], dim=0) + + + + def _t2i_predict_v(self, input_embeds, indexes_image, attn_mask, past_key_values, t, z, + image_token_num, timestep_embeddings=None, image_size=None): + B, L = z.shape[0], z.shape[1] + + outputs = self.language_model.model( + inputs_embeds=input_embeds, + image_gen_indicators=torch.ones((input_embeds.shape[0], input_embeds.shape[1]), dtype=torch.bool, device=input_embeds.device), + indexes=indexes_image, + attention_mask=attn_mask, + past_key_values=past_key_values, + update_cache=False, + use_cache=True, + ) + + if self.use_pixel_head: + merge_size = int(1 / self.downsample_ratio) + token_h = image_size[1] // (self.patch_size * merge_size) + token_w = image_size[0] // (self.patch_size * merge_size) + + img_reshaped = outputs.last_hidden_state[:, -image_token_num:].view(B, token_h, token_w, -1) + img_2d = torch.einsum("b h w c -> b c h w", img_reshaped) + img_2d = img_2d.contiguous().view(B, -1, token_h, token_w) + + smoothed_img_2d = self.fm_modules['fm_head'](img_2d) + + smoothed_reshaped = smoothed_img_2d.view(B, 3, token_h, self.patch_size * merge_size, token_w, self.patch_size * merge_size) + smoothed_reshaped = torch.einsum("b c h p w q -> b h w p q c", smoothed_reshaped) + out_1d = smoothed_reshaped.contiguous().view(B, L, self.patch_size * merge_size * self.patch_size * merge_size * 3) + x_pred = out_1d + else: + if self.use_deep_fm_head: + x_pred = self.fm_modules["fm_head"]( + outputs.last_hidden_state[:, -image_token_num:].view(B*L, -1), t.repeat(B*L) + ).view(B, L, -1) + else: + x_pred = self.fm_modules["fm_head"]( + outputs.last_hidden_state[:, -image_token_num:].view(B, L, -1) + ).view(B, L, -1) + + + v_pred = (x_pred - z) / (1 - t).clamp_min(self.config.t_eps) + return v_pred + + + @torch.no_grad() + def it2i_generate(self, + 
past_key_values_condition, + past_key_values_text_uncondition, + past_key_values_img_uncondition, + text_lens, + cfg_scale=1, + img_cfg_scale=1, + cfg_norm='none', + enable_timestep_shift=True, + timestep_shift=1, + image_size=(256, 256), + num_steps=30, + cfg_interval=(0.1, 1.0), + batch_size=1, + t_eps=0.02, + ): + + self.config.t_eps = t_eps + device, dtype = self.get_cache_device_dtype(past_key_values_condition) + S1, S2, S3 = text_lens + + merge_size = int(1 / self.downsample_ratio) + + token_h = image_size[1] // (self.patch_size * merge_size) + token_w = image_size[0] // (self.patch_size * merge_size) + + indexes_image_condition = self._build_t2i_image_indexes(token_h, token_w, S1, device=device) + indexes_image_text_uncondition = self._build_t2i_image_indexes(token_h, token_w, S2, device=device) + indexes_image_img_uncondition = self._build_t2i_image_indexes(token_h, token_w, S3, device=device) + + for layer_idx in range(len(past_key_values_condition.layers)): + past_key_values_condition.layers[layer_idx].keys = past_key_values_condition.layers[layer_idx].keys.expand(batch_size, *past_key_values_condition.layers[layer_idx].keys.shape[1:]) + past_key_values_condition.layers[layer_idx].values = past_key_values_condition.layers[layer_idx].values.expand(batch_size, *past_key_values_condition.layers[layer_idx].values.shape[1:]) + past_key_values_text_uncondition.layers[layer_idx].keys = past_key_values_text_uncondition.layers[layer_idx].keys.expand(batch_size, *past_key_values_text_uncondition.layers[layer_idx].keys.shape[1:]) + past_key_values_text_uncondition.layers[layer_idx].values = past_key_values_text_uncondition.layers[layer_idx].values.expand(batch_size, *past_key_values_text_uncondition.layers[layer_idx].values.shape[1:]) + past_key_values_img_uncondition.layers[layer_idx].keys = past_key_values_img_uncondition.layers[layer_idx].keys.expand(batch_size, *past_key_values_img_uncondition.layers[layer_idx].keys.shape[1:]) + past_key_values_img_uncondition.layers[layer_idx].values = past_key_values_img_uncondition.layers[layer_idx].values.expand(batch_size, *past_key_values_img_uncondition.layers[layer_idx].values.shape[1:]) + + prepare_flash_kv_cache( + past_key_values_condition, + current_len=token_h * token_w, + batch_size=batch_size, + ) + prepare_flash_kv_cache( + past_key_values_text_uncondition, + current_len=token_h * token_w, + batch_size=batch_size, + ) + prepare_flash_kv_cache( + past_key_values_img_uncondition, + current_len=token_h * token_w, + batch_size=batch_size, + ) + + + # init noise image tokens + grid_h = image_size[1] // self.patch_size + grid_w = image_size[0] // self.patch_size + grid_hw = torch.tensor([[grid_h, grid_w]]*batch_size, device=device) + + noise_scale = self.noise_scale + if self.noise_scale_mode in ("resolution", "dynamic", 'dynamic_sqrt'): + noise_scale = math.sqrt((grid_h*grid_w)/(merge_size**2) / self.noise_scale_base_image_seq_len) + base = float(self.noise_scale_base_image_seq_len) + scale = math.sqrt((grid_h*grid_w)/(merge_size**2)/base) + noise_scale = scale * float(self.noise_scale) + if self.noise_scale_mode == 'dynamic_sqrt': + noise_scale = math.sqrt(noise_scale) + noise_scale = min(noise_scale, self.noise_scale_max_value) + + image_prediction = noise_scale * torch.randn((batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype) + + attention_mask_condition = {"full_attention": None} + attention_mask_text_uncondition = {"full_attention": None} + attention_mask_img_uncondition = {"full_attention": None} + + timesteps = 
torch.linspace(0.0, 1.0, num_steps+1, device=device) + if enable_timestep_shift: + timesteps = self._apply_time_schedule(timesteps, token_h*token_w, timestep_shift) + + for step_i in range(num_steps): + t = timesteps[step_i] + t_next = timesteps[step_i + 1] + + z = self.patchify(image_prediction, self.patch_size * merge_size) + image_input = self.patchify(image_prediction, self.patch_size, channel_first=True) + image_embeds = self.extract_feature(image_input.view(batch_size * grid_h*grid_w, -1), gen_model=True, grid_hw=grid_hw).view(batch_size, token_h*token_w, -1) + t_expanded = t.expand(batch_size*token_h*token_w) + timestep_embeddings = self.fm_modules['timestep_embedder'](t_expanded).view(batch_size, token_h*token_w, -1) + if self.add_noise_scale_embedding: + noise_scale_tensor = torch.full_like(t_expanded, noise_scale/self.noise_scale_max_value) + noise_embeddings = self.fm_modules['noise_scale_embedder'](noise_scale_tensor).view(batch_size, token_h*token_w, -1) + timestep_embeddings += noise_embeddings + image_embeds = image_embeds + timestep_embeddings + + v_pred_condition = self._t2i_predict_v(image_embeds, indexes_image_condition, attention_mask_condition, past_key_values_condition, t, z, image_token_num=token_h*token_w, timestep_embeddings=timestep_embeddings,image_size=image_size) + if t > cfg_interval[0] and t < cfg_interval[1]: + if cfg_scale > 1: + v_pred_text_uncondition = self._t2i_predict_v(image_embeds, indexes_image_text_uncondition, attention_mask_text_uncondition, past_key_values_text_uncondition, t, z, image_token_num=token_h*token_w, timestep_embeddings=timestep_embeddings,image_size=image_size) + else: + v_pred_text_uncondition = 0 + if img_cfg_scale > 1: + v_pred_img_uncondition = self._t2i_predict_v(image_embeds, indexes_image_img_uncondition, attention_mask_img_uncondition, past_key_values_img_uncondition, t, z, image_token_num=token_h*token_w, timestep_embeddings=timestep_embeddings,image_size=image_size) + else: + v_pred_img_uncondition = 0 + + if t > cfg_interval[0] and t < cfg_interval[1]: + v_pred_text = v_pred_text_uncondition + cfg_scale * (v_pred_condition - v_pred_text_uncondition) + if cfg_norm == 'text_channel': + norm_v_condition = torch.norm(v_pred_condition, dim=-1, keepdim=True) + norm_v_cfg = torch.norm(v_pred_text, dim=-1, keepdim=True) + scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) + v_pred_text = v_pred_text * scale + v_pred = v_pred_img_uncondition + img_cfg_scale * (v_pred_text - v_pred_img_uncondition) + if cfg_norm == 'global': + norm_v_condition = torch.norm(v_pred_condition, dim=(1,2), keepdim=True) + norm_v_cfg = torch.norm(v_pred, dim=(1,2), keepdim=True) + scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) + v_pred = v_pred * scale + elif cfg_norm == 'channel': + norm_v_condition = torch.norm(v_pred_condition, dim=-1, keepdim=True) + norm_v_cfg = torch.norm(v_pred, dim=-1, keepdim=True) + scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) + v_pred = v_pred * scale + + else: + v_pred = v_pred_condition + + z = z + (t_next - t) * v_pred + + image_prediction = self.unpatchify(z, self.patch_size * merge_size, image_size[1], image_size[0]) + + clear_flash_kv_cache(past_key_values_condition) + clear_flash_kv_cache(past_key_values_text_uncondition) + clear_flash_kv_cache(past_key_values_img_uncondition) + + return image_prediction + + + def get_cache_device_dtype(self, cache): + """ + Returns (device, dtype) of a DynamicCache. + Assumes all layers share same device/dtype. 
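+ Raises ValueError if the cache has no layers.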
+ """ + for layer in cache.layers: + return layer.device, layer.dtype + raise ValueError("Cache is empty") + + @torch.no_grad() + def t2i_generate(self, + past_key_values_condition, + past_key_values_uncondition, + text_lens, + cfg_scale=1, + timestep_shift=1, + enable_timestep_shift=True, + cfg_norm='none', + image_size=(256, 256), + num_steps=30, + cfg_interval=(0.1, 1.0), + batch_size=1, + t_eps=0.02): + assert cfg_norm in ['cfg_zero_star', 'global', 'none'], f"cfg_norm={cfg_norm}" + merge_size = int(1 / self.downsample_ratio) + self.config.t_eps = t_eps + + token_h = image_size[1] // (self.patch_size * merge_size) + token_w = image_size[0] // (self.patch_size * merge_size) + + device, dtype = self.get_cache_device_dtype(past_key_values_condition) + S1, S2 = text_lens + + indexes_image_condition = self._build_t2i_image_indexes(token_h, token_w, S1, device=device) + indexes_image_uncondition = self._build_t2i_image_indexes(token_h, token_w, S2, device=device) + + for layer_idx in range(len(past_key_values_condition.layers)): + past_key_values_condition.layers[layer_idx].keys = past_key_values_condition.layers[layer_idx].keys.expand(batch_size, *past_key_values_condition.layers[layer_idx].keys.shape[1:]) + past_key_values_condition.layers[layer_idx].values = past_key_values_condition.layers[layer_idx].values.expand(batch_size, *past_key_values_condition.layers[layer_idx].values.shape[1:]) + past_key_values_uncondition.layers[layer_idx].keys = past_key_values_uncondition.layers[layer_idx].keys.expand(batch_size, *past_key_values_uncondition.layers[layer_idx].keys.shape[1:]) + past_key_values_uncondition.layers[layer_idx].values = past_key_values_uncondition.layers[layer_idx].values.expand(batch_size, *past_key_values_uncondition.layers[layer_idx].values.shape[1:]) + + # prepare flash cache once + prepare_flash_kv_cache( + past_key_values_condition, + current_len=token_h * token_w, + batch_size=batch_size, + ) + prepare_flash_kv_cache( + past_key_values_uncondition, + current_len=token_h * token_w, + batch_size=batch_size, + ) + + # init noise image tokens + grid_h = image_size[1] // self.patch_size + grid_w = image_size[0] // self.patch_size + grid_hw = torch.tensor([[grid_h, grid_w]]*batch_size, device=device) + + noise_scale = self.noise_scale + if self.noise_scale_mode in ("resolution", "dynamic", 'dynamic_sqrt'): + noise_scale = math.sqrt((grid_h*grid_w)/(merge_size**2) / self.noise_scale_base_image_seq_len) + base = float(self.noise_scale_base_image_seq_len) + scale = math.sqrt((grid_h*grid_w)/(merge_size**2)/base) + noise_scale = scale * float(self.noise_scale) + if self.noise_scale_mode == 'dynamic_sqrt': + noise_scale = math.sqrt(noise_scale) + noise_scale = min(noise_scale, self.noise_scale_max_value) + + image_prediction = noise_scale * torch.randn((batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype) + + attention_mask_condition = {"full_attention": None} + attention_mask_uncondition = {"full_attention": None} + + timesteps = torch.linspace(0.0, 1.0, num_steps+1, device=device) + + if enable_timestep_shift: + timesteps = self._apply_time_schedule(timesteps, token_h*token_w, timestep_shift) + + for step_i in range(num_steps): + t = timesteps[step_i] + t_next = timesteps[step_i + 1] + + z = self.patchify(image_prediction, self.patch_size * merge_size) + image_input = self.patchify(image_prediction, self.patch_size, channel_first=True) + image_embeds = self.extract_feature(image_input.view(batch_size * grid_h*grid_w, -1), gen_model=True, 
grid_hw=grid_hw).view(batch_size, token_h*token_w, -1) + t_expanded = t.expand(batch_size*token_h*token_w) + timestep_embeddings = self.fm_modules['timestep_embedder'](t_expanded).view(batch_size, token_h*token_w, -1) + if self.add_noise_scale_embedding: + noise_scale_tensor = torch.full_like(t_expanded, noise_scale / self.noise_scale_max_value) + noise_embeddings = self.fm_modules['noise_scale_embedder'](noise_scale_tensor).view(batch_size, token_h*token_w, -1) + timestep_embeddings += noise_embeddings + image_embeds = image_embeds + timestep_embeddings + + + v_pred_condition = self._t2i_predict_v(image_embeds, indexes_image_condition, attention_mask_condition, past_key_values_condition, t, z, image_token_num=token_h*token_w, + timestep_embeddings=timestep_embeddings, image_size=image_size) + + + if t > cfg_interval[0] and t < cfg_interval[1] and cfg_scale > 1: + v_pred_uncondition = self._t2i_predict_v(image_embeds, indexes_image_uncondition, attention_mask_uncondition, past_key_values_uncondition, t, z, image_token_num=token_h*token_w, + timestep_embeddings=timestep_embeddings, image_size=image_size) + if cfg_norm == 'cfg_zero_star': + positive_flat = v_pred_condition.view(batch_size, -1) + negative_flat = v_pred_uncondition.view(batch_size, -1) + + alpha = optimized_scale(positive_flat,negative_flat) + alpha = alpha.view(batch_size, *([1] * (len(v_pred_condition.shape) - 1))) + alpha = alpha.to(positive_flat.dtype) + + if (step_i <= 0): + v_pred = v_pred_condition*0. + else: + v_pred = v_pred_uncondition * alpha + cfg_scale * (v_pred_condition - v_pred_uncondition * alpha) + else: + v_pred = v_pred_uncondition + cfg_scale * (v_pred_condition - v_pred_uncondition) + if cfg_norm == 'global': + norm_v_condition = torch.norm(v_pred_condition, dim=(1,2), keepdim=True) + norm_v_cfg = torch.norm(v_pred, dim=(1,2), keepdim=True) + scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) + v_pred = v_pred * scale + else: + v_pred = v_pred_condition + + z = z + (t_next - t) * v_pred + + image_prediction = self.unpatchify(z, self.patch_size * merge_size, image_size[1], image_size[0]) + + clear_flash_kv_cache(past_key_values_condition) + clear_flash_kv_cache(past_key_values_uncondition) + + return image_prediction + + + @property + def lm_head(self): + return self.language_model.get_output_embeddings() + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + return self.language_model.set_input_embeddings(value) + + def set_output_embeddings(self, value): + return self.language_model.set_output_embeddings(value) + + def get_thw_indexes(self, input_ids, grid_hw=None): + img_start_shift = torch.cat([torch.zeros(1, dtype=torch.long).to(input_ids.device), + (input_ids == self.img_start_token_id).long()], dim=0)[:-1] + not_img_token = (input_ids != self.img_context_token_id).long() + t_indexes = ((img_start_shift + not_img_token).cumsum(0) - 1) + h_indexes = torch.zeros_like(t_indexes).to(t_indexes.device) + w_indexes = torch.zeros_like(t_indexes).to(t_indexes.device) + + if grid_hw is not None: + selected = (input_ids == self.img_context_token_id) + if selected.long().sum() > 0: + abs_pos_w, abs_pos_h = build_abs_positions_from_grid_hw( + grid_hw // int(1 / self.downsample_ratio), device=t_indexes.device) + h_indexes[selected] = abs_pos_h.to(t_indexes.device, t_indexes.dtype) + w_indexes[selected] = 
abs_pos_w.to(t_indexes.device, t_indexes.dtype) + return torch.stack([t_indexes, h_indexes, w_indexes], dim=0) + + +NORM_MEAN = [0.5, 0.5, 0.5] +NORM_STD = [0.5, 0.5, 0.5] + +class NEOX2I: + def __init__(self, model_path, device): + self.device = device + self.model: NEOChatModel = NEOChatModel.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + ).to(self.device) + + self.model.eval() + + def _denorm(self, x: torch.Tensor, mean=NORM_MEAN, std=NORM_STD): + """ + x: [B,3,H,W] normalized ((img-mean)/std). returns [0,1] clamped. + """ + mean = torch.tensor(mean, device=x.device, dtype=x.dtype).view(1, 3, 1, 1) + std = torch.tensor(std, device=x.device, dtype=x.dtype).view(1, 3, 1, 1) + return (x * std + mean).clamp(0, 1) + + def _get_dynamic_cache(self, past_kv): + """ + past_kv ( L, 2, H // 2, P * S, D) + """ + past_kv_dc = DynamicCache(config=self.model.language_model.model.config) + L, _, H, S, D = past_kv.shape + for layer_idx in range(L): + k = past_kv[layer_idx][0].unsqueeze(0).to(self.device, non_blocking=True) + v = past_kv[layer_idx][1].unsqueeze(0).to(self.device, non_blocking=True) + past_kv_dc.update(key_states=k, value_states=v, layer_idx=layer_idx,) + return past_kv_dc + + + def t2i(self, past_kv, past_kv_txt, param: X2IParams): + past_kv_dc = self._get_dynamic_cache(past_kv) + past_kv_txt_dc = self._get_dynamic_cache(past_kv_txt) + text_lens = (param.past_kvcache.get_compressed_len(), + param.past_kvcache_text.get_compressed_len()) + output = self.model.t2i_generate( + past_key_values_condition=past_kv_dc, + past_key_values_uncondition=past_kv_txt_dc, + text_lens=text_lens, + cfg_norm=param.get_cfg_norm(), + cfg_scale=param.guidance_scale, + image_size=(param.width, param.height), + num_steps=param.steps, + batch_size=param.num_images) + + return self._post_process(output) + + def _post_process(self, output): + images = self._denorm(output) + images = (images.clamp(0, 1) * 255.0).round().to(torch.uint8).cpu() + + base64_images = [ + base64.b64encode(io.encode_jpeg(img).numpy()).decode("utf-8") + for img in images + ] + return base64_images + + def it2i(self, past_kv, past_kv_txt, past_kv_img, param: X2IParams): + past_kv_dc = self._get_dynamic_cache(past_kv) + past_kv_txt_dc = self._get_dynamic_cache(past_kv_txt) + past_kv_img_dc = self._get_dynamic_cache(past_kv_img) + text_lens = (param.past_kvcache.get_compressed_len(), + param.past_kvcache_text.get_compressed_len(), + param.past_kvcache_img.get_compressed_len()) + output = self.model.it2i_generate( + past_key_values_condition=past_kv_dc, + past_key_values_text_uncondition=past_kv_txt_dc, + past_key_values_img_uncondition=past_kv_img_dc, + text_lens=text_lens, + cfg_norm=param.get_cfg_norm(), + cfg_scale=param.guidance_scale, + img_cfg_scale=param.image_guidance_scale, + image_size=(param.width, param.height), + num_steps=param.steps, + batch_size=param.num_images, + ) + + return self._post_process(output) diff --git a/lightllm/server/x2i_server/naive/modeling_neo_vit.py b/lightllm/server/x2i_server/naive/modeling_neo_vit.py new file mode 100644 index 0000000000..63ded83e7d --- /dev/null +++ b/lightllm/server/x2i_server/naive/modeling_neo_vit.py @@ -0,0 +1,235 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.modeling_outputs import BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel + +from .configuration_neo_vit import NEOVisionConfig + + +def precompute_rope_freqs_sincos( 
+ dim: int, max_position: int, base: float = 10000.0, device=None +): + """Precompute the cos and sin values for RoPE (1D).""" + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim)) + t = torch.arange(max_position, device=device).type_as(inv_freq) + freqs = torch.outer(t, inv_freq) + return torch.cos(freqs), torch.sin(freqs) + + +def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): + """ + Compute patch coordinates (x, y) + + Args: + grid_hw: (B, 2) tensor representing (H, W) per image + """ + device = grid_hw.device + B = grid_hw.shape[0] + + # Get the number of patches per image + H = grid_hw[:, 0] + W = grid_hw[:, 1] + N = H * W + N_total = N.sum() + + # Create the batch index for each patch (B x patch count) + patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,) + + # Generate intra-image patch index (row-major order) + patch_id_within_image = torch.arange(N_total, device=device) + patch_id_within_image = patch_id_within_image - torch.cumsum( + torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0 + )[patch_to_sample] + + # Get H/W for each patch according to its image + W_per_patch = W[patch_to_sample] + abs_x = patch_id_within_image % W_per_patch + abs_y = patch_id_within_image // W_per_patch + + return abs_x, abs_y + + +def apply_rotary_emb_1d( + x: torch.Tensor, + cos_cached: torch.Tensor, + sin_cached: torch.Tensor, + positions: torch.Tensor, +): + """Apply 1D RoPE to part of the input tensor.""" + # x: (..., seq_len, dim_part) + # positions: (..., seq_len) + # cos_cached: (max_pos, dim_part / 2) + + cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2) + sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2) + + x1 = x[..., 0::2] + x2 = x[..., 1::2] + + rotated_x1 = x1 * cos - x2 * sin + rotated_x2 = x1 * sin + x2 * cos + + x_rotated = torch.empty_like(x) + x_rotated[..., 0::2] = rotated_x1 + x_rotated[..., 1::2] = rotated_x2 + return x_rotated + + +def apply_2d_rotary_pos_emb( + x: torch.Tensor, + cos_cached_x: torch.Tensor, + sin_cached_x: torch.Tensor, + cos_cached_y: torch.Tensor, + sin_cached_y: torch.Tensor, + abs_positions_x: torch.Tensor, + abs_positions_y: torch.Tensor +): + """Apply 2D RoPE to the input tensor x.""" + dim = x.shape[-1] + dim_half = dim // 2 + + # Assume the first half of the embedding is used for RoPE in one direction and the second half for the other direction, + # e.g. the first half for the X coordinate and the second half for the Y coordinate (or the reverse, as long as it stays consistent) + x_part_1 = x[..., :dim_half] + x_part_2 = x[..., dim_half:] + + # Apply the rotation associated with abs_positions_x to x_part_1 + rotated_part_1 = apply_rotary_emb_1d( + x_part_1, cos_cached_x, sin_cached_x, abs_positions_x + ) + # Apply the rotation associated with abs_positions_y to x_part_2 + rotated_part_2 = apply_rotary_emb_1d( + x_part_2, cos_cached_y, sin_cached_y, abs_positions_y + ) + + # Concatenate them back together, keeping the same order as the split above. + return torch.cat((rotated_part_1, rotated_part_2), dim=-1) + + +class NEOVisionEmbeddings(nn.Module): + """ + Embedding Module for Vision.
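+
+ Pixel patches are embedded by a patch_size-strided Conv2d + GELU, 2D RoPE is
+ applied over the (x, y) patch coordinates derived from grid_hw, and a strided
+ "dense_embedding" Conv2d projects the result to the LLM hidden size while
+ reducing the token count by downsample_factor**2 (see forward below).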
+ """ + + def __init__(self, config: NEOVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.llm_embed_dim = config.llm_hidden_size[0] + self.downsample_factor = int(1 / config.downsample_ratio[0]) + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + ) + self.dense_embedding = nn.Conv2d( + in_channels=self.embed_dim, out_channels=self.llm_embed_dim, kernel_size=self.downsample_factor, stride=self.downsample_factor + ) + self.gelu = nn.GELU() + + self.rope_dim_part = self.embed_dim // 2 + cos_x, sin_x = precompute_rope_freqs_sincos( + self.rope_dim_part, config.max_position_embeddings_vision, base=config.rope_theta_vision, device=None + ) + cos_y, sin_y = precompute_rope_freqs_sincos( + self.rope_dim_part, config.max_position_embeddings_vision, base=config.rope_theta_vision, device=None + ) + + self.register_buffer("cos_cached_x", cos_x, persistent=False) + self.register_buffer("sin_cached_x", sin_x, persistent=False) + self.register_buffer("cos_cached_y", cos_y, persistent=False) + self.register_buffer("sin_cached_y", sin_y, persistent=False) + + def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw): + """ + Apply 2D Rotary Position Embedding to the patch embeddings. + """ + abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device) + embeddings = apply_2d_rotary_pos_emb( + patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32 + self.cos_cached_x, self.sin_cached_x, + self.cos_cached_y, self.sin_cached_y, + abs_pos_x, + abs_pos_y + ).to(self.patch_embedding.weight.dtype) + return embeddings + + def forward(self, pixel_values: torch.FloatTensor, grid_hw=None) -> torch.Tensor: + + pixel_values = pixel_values.view( # + -1, + 3, + self.patch_size, + self.patch_size, + ) # [28072, 768] -> [28072, 3, 16, 16] + patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim) + self.cos_cached_x = self.cos_cached_x.to(patch_embeds.device) + self.sin_cached_x = self.sin_cached_x.to(patch_embeds.device) + self.cos_cached_y = self.cos_cached_y.to(patch_embeds.device) + self.sin_cached_y = self.sin_cached_y.to(patch_embeds.device) + patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) # [28072, 1024] + assert (grid_hw[:,0] * grid_hw[:,1]).sum() == patch_embeds.shape[0] + + patches_list = [] + cur_position = 0 + for i in range(grid_hw.shape[0]): + h, w = grid_hw[i] + patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0) + patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2)) + patches_per_img = patches_per_img.permute(0, 2, 3, 1) + patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1])) + cur_position += h * w + + embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C) + + assert cur_position == patch_embeds.shape[0] + assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor**2) + + return embeddings + + +class NEOVisionModel(PreTrainedModel): + main_input_name = 'pixel_values' + _supports_flash_attn_2 = True + supports_gradient_checkpointing = True + config_class = NEOVisionConfig + # support transformers 4.51.+ + _tp_plan = '' + + def __init__(self, config: NEOVisionConfig): + super().__init__(config) + self.config = config + + self.embeddings = NEOVisionEmbeddings(config) + + 
def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_embeds: Optional[torch.FloatTensor] = None, + grid_hw: Optional[torch.Tensor] = None + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None and pixel_embeds is None: + raise ValueError('You have to specify pixel_values or pixel_embeds') + + if pixel_embeds is not None: + hidden_states = pixel_embeds + else: + assert pixel_values.dim() == 2, f"pixel_values must be 2D for native resolution, got: {pixel_values.dim()}" + hidden_states = self.embeddings(pixel_values, grid_hw=grid_hw) + + return BaseModelOutputWithPooling( + last_hidden_state=hidden_states, + pooler_output=None, + hidden_states=None, + attentions=None, + ) diff --git a/lightllm/server/x2i_server/naive/modeling_qwen3.py b/lightllm/server/x2i_server/naive/modeling_qwen3.py new file mode 100644 index 0000000000..5a4c5acc40 --- /dev/null +++ b/lightllm/server/x2i_server/naive/modeling_qwen3.py @@ -0,0 +1,1117 @@ +from typing import Callable, Optional, Union + +import torch +import torch._dynamo +from torch import nn + +import copy +import math +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.generation import GenerationMixin +from transformers.integrations import use_kernel_forward_from_hub +from transformers.masking_utils import create_causal_mask +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_layers import ( + GenericForQuestionAnswering, + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import check_model_inputs +from transformers import Qwen3Config + +from flash_attn import flash_attn_func + + +def create_block_causal_mask(index: torch.Tensor): + """ + index: (L) + return: (1, 1, L, L) block-wise causal attention mask + """ + L = index.size(0) + idx_i = index.unsqueeze(1).expand(L, L) + idx_j = index.unsqueeze(0).expand(L, L) + + arange = torch.arange(L, device=index.device) + mask = (idx_j == idx_i) | (arange.unsqueeze(0) <= arange.unsqueeze(1)) + + return torch.where(mask[None, None, :, :] > 0, torch.tensor(0.0), torch.tensor(float('-inf'))) + + +def visualize_mask(mask: torch.Tensor, i: int = 0, j: int = 12): + """ + mask: (1,1, L, L) + """ + submask = torch.where(mask[0, 0, :, :] == 0, torch.tensor(1.0), torch.tensor(0.0)) + submask = mask[i:j, i:j].int().cpu().numpy() + for row in submask: + print(" ".join(map(str, row))) + + +@use_kernel_forward_from_hub("RMSNorm") +class Qwen3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + Qwen3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = 
nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen3MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class Qwen3RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: Qwen3Config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + base_rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + def _rope_init_fn_keep_freq_range(cfg: Qwen3Config, dev=None): + inv_freq, attention_scaling = base_rope_init_fn(cfg, dev) + + cfg2 = copy.deepcopy(cfg) + head_dim = getattr(cfg2, "head_dim", None) + if head_dim is None: + head_dim = cfg2.hidden_size // cfg2.num_attention_heads + setattr(cfg2, "head_dim", head_dim) + cfg2.head_dim = int(head_dim) * 2 + + inv_freq_full, _ = base_rope_init_fn(cfg2, dev) + inv_freq = inv_freq_full[::2] + + return inv_freq, attention_scaling + + self.rope_init_fn = _rope_init_fn_keep_freq_range + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + self.q_proj_mot_gen = nn.Linear( + config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias + ) + + self.k_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj_mot_gen = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + + self.v_proj = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.v_proj_mot_gen = nn.Linear( + config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + self.o_proj_mot_gen = nn.Linear( + config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias + ) + + self.q_norm = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps) # unlike olmo, only on the head dim! 
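+ # The head dim is split into two rotary streams: the first half is the temporal
+ # ("t") stream normed by q_norm/k_norm, the second half is the spatial ("hw")
+ # stream normed by q_norm_hw/k_norm_hw and later chunked into h/w quarters
+ # (see forward_und below). The *_mot_gen copies are the parallel parameter set
+ # applied to image-generation tokens in forward_gen.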
+ self.q_norm_mot_gen = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps) + self.q_norm_hw = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps) + self.q_norm_hw_mot_gen = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps) + + self.k_norm = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + self.k_norm_mot_gen = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps) + self.k_norm_hw = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps) # thus post q_norm does not need reshape + self.k_norm_hw_mot_gen = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps) + + self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + + t_config = copy.deepcopy(config) + t_config.head_dim = config.head_dim // 2 + self.rotary_emb = Qwen3RotaryEmbedding(config=t_config) + + hw_config = copy.deepcopy(config) + hw_config.head_dim = config.head_dim // 4 + hw_config.rope_theta = config.rope_theta_hw + hw_config.max_position_embeddings = config.max_position_embeddings_hw + self.rotary_emb_hw = Qwen3RotaryEmbedding(config=hw_config) + + def forward_und( + self, + hidden_states: torch.Tensor, + indexes: Optional[torch.LongTensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert self.config._attn_implementation == "eager" + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + query_states_t, query_states_hw = query_states.chunk(2, dim=-1) + query_states_t = self.q_norm(query_states_t).transpose(1, 2) + query_states_hw = self.q_norm_hw(query_states_hw).transpose(1, 2) + query_states_h, query_states_w = query_states_hw.chunk(2, dim=-1) + + key_states = self.k_proj(hidden_states).view(hidden_shape) + key_states_t, key_states_hw = key_states.chunk(2, dim=-1) + key_states_t = self.k_norm(key_states_t).transpose(1, 2) + key_states_hw = self.k_norm_hw(key_states_hw).transpose(1, 2) + key_states_h, key_states_w = key_states_hw.chunk(2, dim=-1) + + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos_t, sin_t = self.rotary_emb(hidden_states, indexes[0].unsqueeze(0)) + query_states_t, key_states_t = apply_rotary_pos_emb(query_states_t, key_states_t, cos_t, sin_t) + + cos_h, sin_h = self.rotary_emb_hw(hidden_states, indexes[1].unsqueeze(0)) + query_states_h, key_states_h = apply_rotary_pos_emb(query_states_h, key_states_h, cos_h, sin_h) + + cos_w, sin_w = self.rotary_emb_hw(hidden_states, indexes[2].unsqueeze(0)) + query_states_w, key_states_w = apply_rotary_pos_emb(query_states_w, key_states_w, cos_w, sin_w) + + query_states = torch.cat([query_states_t, query_states_h, query_states_w], dim=-1) + key_states = torch.cat([key_states_t, key_states_h, key_states_w], dim=-1) + + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + update_cache = kwargs.get("update_cache", True) + if update_cache: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs=None) + else: + # 
only use the past key values but do not append the current one + layer = past_key_values.layers[self.layer_idx] + past_k, past_v = layer.keys, layer.values + + if past_k is not None: + key_states = torch.cat([past_k, key_states], dim=2) # concat on seq_len + value_states = torch.cat([past_v, value_states], dim=2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + # def forward_gen( + # self, + # hidden_states: torch.Tensor, + # indexes: Optional[torch.LongTensor], + # attention_mask: Optional[torch.Tensor], + # past_key_values: Optional[Cache] = None, + # cache_position: Optional[torch.LongTensor] = None, + # **kwargs: Unpack[FlashAttentionKwargs], + # ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + # assert self.config._attn_implementation == "eager" + # input_shape = hidden_states.shape[:-1] + # hidden_shape = (*input_shape, -1, self.head_dim) + + # query_states = self.q_proj_mot_gen(hidden_states).view(hidden_shape) + # query_states_t, query_states_hw = query_states.chunk(2, dim=-1) + # query_states_t = self.q_norm_mot_gen(query_states_t).transpose(1, 2) + # query_states_hw = self.q_norm_hw_mot_gen(query_states_hw).transpose(1, 2) + # query_states_h, query_states_w = query_states_hw.chunk(2, dim=-1) + + # key_states = self.k_proj_mot_gen(hidden_states).view(hidden_shape) + # key_states_t, key_states_hw = key_states.chunk(2, dim=-1) + # key_states_t = self.k_norm_mot_gen(key_states_t).transpose(1, 2) + # key_states_hw = self.k_norm_hw_mot_gen(key_states_hw).transpose(1, 2) + # key_states_h, key_states_w = key_states_hw.chunk(2, dim=-1) + + # value_states = self.v_proj_mot_gen(hidden_states).view(hidden_shape).transpose(1, 2) + + # cos_t, sin_t = self.rotary_emb(hidden_states, indexes[0].unsqueeze(0)) + # query_states_t, key_states_t = apply_rotary_pos_emb(query_states_t, key_states_t, cos_t, sin_t) + + # cos_h, sin_h = self.rotary_emb_hw(hidden_states, indexes[1].unsqueeze(0)) + # query_states_h, key_states_h = apply_rotary_pos_emb(query_states_h, key_states_h, cos_h, sin_h) + + # cos_w, sin_w = self.rotary_emb_hw(hidden_states, indexes[2].unsqueeze(0)) + # query_states_w, key_states_w = apply_rotary_pos_emb(query_states_w, key_states_w, cos_w, sin_w) + + # query_states = torch.cat([query_states_t, query_states_h, query_states_w], dim=-1) + # key_states = torch.cat([key_states_t, key_states_h, key_states_w], dim=-1) + + + # if past_key_values is not None: + # # sin and cos are specific to RoPE models; cache_position needed for the static cache + # # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # # key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + # update_cache = kwargs.get("update_cache", True) + # if update_cache: + # key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs=None) + # else: + # # only use the past key values but do not append the current one + # layer = 
past_key_values.layers[self.layer_idx] + # past_k, past_v = layer.keys, layer.values + + # if past_k is not None: + # key_states = torch.cat([past_k, key_states], dim=2) # concat on seq_len + # value_states = torch.cat([past_v, value_states], dim=2) + + # attention_interface: Callable = eager_attention_forward + # if self.config._attn_implementation != "eager": + # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + # attn_output, attn_weights = attention_interface( + # self, + # query_states, + # key_states, + # value_states, + # attention_mask, + # dropout=0.0 if not self.training else self.attention_dropout, + # scaling=self.scaling, + # sliding_window=self.sliding_window, # diff with Llama + # **kwargs, + # ) + + # attn_output = attn_output.reshape(*input_shape, -1).contiguous() + # attn_output = self.o_proj_mot_gen(attn_output) + # return attn_output, attn_weights + + def forward_gen( + self, + hidden_states: torch.Tensor, + indexes: Optional[torch.LongTensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + # ----------------------------- + # Build q / k / v for current tokens + # Internal layout before flash: + # q/k/v: [B, H, S, D] + # Flash layout: + # q/k/v: [B, S, H, D] + # ----------------------------- + query_states = self.q_proj_mot_gen(hidden_states).view(hidden_shape) + query_states_t, query_states_hw = query_states.chunk(2, dim=-1) + query_states_t = self.q_norm_mot_gen(query_states_t).transpose(1, 2) # [B,H,S,D/2] + query_states_hw = self.q_norm_hw_mot_gen(query_states_hw).transpose(1, 2) + query_states_h, query_states_w = query_states_hw.chunk(2, dim=-1) + + key_states = self.k_proj_mot_gen(hidden_states).view(hidden_shape) + key_states_t, key_states_hw = key_states.chunk(2, dim=-1) + key_states_t = self.k_norm_mot_gen(key_states_t).transpose(1, 2) # [B,H,S,D/2] + key_states_hw = self.k_norm_hw_mot_gen(key_states_hw).transpose(1, 2) + key_states_h, key_states_w = key_states_hw.chunk(2, dim=-1) + + value_states = self.v_proj_mot_gen(hidden_states).view(hidden_shape).transpose(1, 2) # [B,H,S,D] + + # RoPE + cos_t, sin_t = self.rotary_emb(hidden_states, indexes[0].unsqueeze(0)) + query_states_t, key_states_t = apply_rotary_pos_emb(query_states_t, key_states_t, cos_t, sin_t) + + cos_h, sin_h = self.rotary_emb_hw(hidden_states, indexes[1].unsqueeze(0)) + query_states_h, key_states_h = apply_rotary_pos_emb(query_states_h, key_states_h, cos_h, sin_h) + + cos_w, sin_w = self.rotary_emb_hw(hidden_states, indexes[2].unsqueeze(0)) + query_states_w, key_states_w = apply_rotary_pos_emb(query_states_w, key_states_w, cos_w, sin_w) + + # concat along head_dim + # query/key current layout: [B, H, S, D] + query_states = torch.cat([query_states_t, query_states_h, query_states_w], dim=-1) + key_states = torch.cat([key_states_t, key_states_h, key_states_w], dim=-1) + + update_cache = kwargs.get("update_cache", True) + + # ------------------------------------------------------------------ + # Flash path: + # Only use when there is no explicit dense mask. 
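+ # (flash_attn_func takes no additive attention mask, so masked inputs fall
+ # through to the eager path further below.)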
+ # This is exactly the t2i denoising use case: + # current image tokens attend to [prefix + current image tokens] + # fully bidirectional inside current block => causal=False + # ------------------------------------------------------------------ + if attention_mask is None: + # Convert current q/k/v to flash layout [B, S, H, D] + q = query_states.transpose(1, 2).contiguous() + k_cur = key_states.transpose(1, 2).contiguous() + v_cur = value_states.transpose(1, 2).contiguous() + + if past_key_values is not None: + if update_cache: + # Rare path, keep compatibility. + # past_key_values.update expects [B,H,S,D] + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs=None + ) + k = key_states.transpose(1, 2).contiguous() + v = value_states.transpose(1, 2).contiguous() + else: + # Optimized path: + # use preallocated flash_k_cache / flash_v_cache + layer = past_key_values.layers[self.layer_idx] + + if ( + hasattr(layer, "flash_k_cache") + and layer.flash_k_cache is not None + and hasattr(layer, "flash_v_cache") + and layer.flash_v_cache is not None + ): + prefix_len = layer.flash_prefix_len + cur_len = k_cur.shape[1] + + # overwrite current segment in-place + layer.flash_k_cache[:, prefix_len:prefix_len + cur_len].copy_(k_cur) + layer.flash_v_cache[:, prefix_len:prefix_len + cur_len].copy_(v_cur) + + k = layer.flash_k_cache[:, :prefix_len + cur_len] + v = layer.flash_v_cache[:, :prefix_len + cur_len] + else: + # fallback if user forgot to prepare flash cache + layer = past_key_values.layers[self.layer_idx] + past_k, past_v = layer.keys, layer.values + + if past_k is not None: + past_k = past_k.transpose(1, 2).contiguous() + past_v = past_v.transpose(1, 2).contiguous() + k = torch.cat([past_k, k_cur], dim=1) + v = torch.cat([past_v, v_cur], dim=1) + else: + k = k_cur + v = v_cur + else: + k = k_cur + v = v_cur + + # sanity checks + assert q.ndim == 4 and k.ndim == 4 and v.ndim == 4 + assert q.shape[0] == k.shape[0] == v.shape[0], (q.shape, k.shape, v.shape) + assert k.shape[1] == v.shape[1], (k.shape, v.shape) + assert k.shape[2] == v.shape[2], (k.shape, v.shape) + assert q.shape[3] == k.shape[3] == v.shape[3], (q.shape, k.shape, v.shape) + + attn_output = flash_attn_func( + q, + k, + v, + dropout_p=0.0 if not self.training else self.attention_dropout, + softmax_scale=self.scaling, + causal=False, + ) # [B, S_q, H_q, D] + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj_mot_gen(attn_output) + return attn_output, None + + # ------------------------------------------------------------------ + # Original eager fallback path + # ------------------------------------------------------------------ + if past_key_values is not None: + if update_cache: + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs=None + ) + else: + layer = past_key_values.layers[self.layer_idx] + past_k, past_v = layer.keys, layer.values + if past_k is not None: + key_states = torch.cat([past_k, key_states], dim=2) + value_states = torch.cat([past_v, value_states], dim=2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + 
sliding_window=self.sliding_window,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        attn_output = self.o_proj_mot_gen(attn_output)
+        return attn_output, attn_weights
+
+    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        image_gen_indicators: torch.Tensor,
+        exist_non_image_gen_tokens: bool,
+        exist_image_gen_tokens: bool,
+        indexes: Optional[torch.LongTensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        if exist_non_image_gen_tokens and not exist_image_gen_tokens:
+            return self.forward_und(hidden_states, indexes, attention_mask, past_key_values, cache_position, **kwargs)
+        if not exist_non_image_gen_tokens and exist_image_gen_tokens:
+            return self.forward_gen(hidden_states, indexes, attention_mask, past_key_values, cache_position, **kwargs)
+
+        assert self.config._attn_implementation == "eager"
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        # Mixed batch: the boolean mask image_gen_indicators routes each token through
+        # either the understanding projections or their *_mot_gen image-generation twins.
+        query_states = hidden_states.new_zeros((*input_shape, self.config.num_attention_heads*self.head_dim))
+        if exist_non_image_gen_tokens:
+            query_states[~image_gen_indicators] = self.q_proj(hidden_states[~image_gen_indicators])
+        if exist_image_gen_tokens:
+            query_states[image_gen_indicators] = self.q_proj_mot_gen(hidden_states[image_gen_indicators])
+        query_states_t, query_states_hw = query_states.chunk(2, dim=-1)
+
+        _query_states_hw = query_states_hw.new_zeros(query_states_hw.shape)
+        if exist_non_image_gen_tokens:
+            _query_states_hw[~image_gen_indicators] = self.q_norm_hw(query_states_hw[~image_gen_indicators])
+        if exist_image_gen_tokens:
+            _query_states_hw[image_gen_indicators] = self.q_norm_hw_mot_gen(query_states_hw[image_gen_indicators])
+        query_states_hw = _query_states_hw.transpose(1, 2)
+        query_states_h, query_states_w = query_states_hw.chunk(2, dim=-1)
+
+        key_states = hidden_states.new_zeros((*input_shape, self.config.num_key_value_heads*self.head_dim))
+        if exist_non_image_gen_tokens:
+            key_states[~image_gen_indicators] = self.k_proj(hidden_states[~image_gen_indicators])
+        if exist_image_gen_tokens:
+            key_states[image_gen_indicators] = self.k_proj_mot_gen(hidden_states[image_gen_indicators])
+        key_states_t, key_states_hw = key_states.chunk(2, dim=-1)
+
+        _key_states_hw = key_states_hw.new_zeros(key_states_hw.shape)
+        if exist_non_image_gen_tokens:
+            _key_states_hw[~image_gen_indicators] = self.k_norm_hw(key_states_hw[~image_gen_indicators])
+        if exist_image_gen_tokens:
+            _key_states_hw[image_gen_indicators] = self.k_norm_hw_mot_gen(key_states_hw[image_gen_indicators])
+        key_states_hw = _key_states_hw.transpose(1, 2)
+        key_states_h, key_states_w = key_states_hw.chunk(2, dim=-1)
+
+        value_states = hidden_states.new_zeros((*input_shape, self.config.num_key_value_heads*self.head_dim))
+        if exist_non_image_gen_tokens:
+            value_states[~image_gen_indicators] = self.v_proj(hidden_states[~image_gen_indicators])
+        if exist_image_gen_tokens:
+            value_states[image_gen_indicators] = self.v_proj_mot_gen(hidden_states[image_gen_indicators])
+        value_states = value_states.view(hidden_shape).transpose(1, 2)
+
+        cos_t, sin_t = self.rotary_emb(hidden_states, indexes[0].unsqueeze(0))
+        query_states_t, key_states_t = apply_rotary_pos_emb(query_states_t, key_states_t, cos_t, sin_t)
+
+        cos_h, sin_h = 
self.rotary_emb_hw(hidden_states, indexes[1].unsqueeze(0)) + query_states_h, key_states_h = apply_rotary_pos_emb(query_states_h, key_states_h, cos_h, sin_h) + + cos_w, sin_w = self.rotary_emb_hw(hidden_states, indexes[2].unsqueeze(0)) + query_states_w, key_states_w = apply_rotary_pos_emb(query_states_w, key_states_w, cos_w, sin_w) + + query_states = torch.cat([query_states_t, query_states_h, query_states_w], dim=-1) + key_states = torch.cat([key_states_t, key_states_h, key_states_w], dim=-1) + + + if past_key_values is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + update_cache = kwargs.get("update_cache", True) + if update_cache: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs=None) + else: + # only use the past key values but do not append the current one + layer = past_key_values.layers[self.layer_idx] + past_k, past_v = layer.keys, layer.values + + if past_k is not None: + key_states = torch.cat([past_k, key_states], dim=2) # concat on seq_len + value_states = torch.cat([past_v, value_states], dim=2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + + _attn_output = attn_output.new_zeros((*input_shape, self.config.hidden_size)) + if exist_non_image_gen_tokens: + _attn_output[~image_gen_indicators] = self.o_proj(attn_output[~image_gen_indicators]) + if exist_image_gen_tokens: + _attn_output[image_gen_indicators] = self.o_proj_mot_gen(attn_output[image_gen_indicators]) + + attn_output = _attn_output + return attn_output, attn_weights + + +class Qwen3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: Qwen3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) + + self.mlp = Qwen3MLP(config) + self.mlp_mot_gen = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm_mot_gen = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm_mot_gen = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attention_type = config.layer_types[layer_idx] + + def forward_und( + self, + hidden_states: torch.Tensor, + image_gen_indicators: torch.Tensor, + exist_non_image_gen_tokens: bool, + exist_image_gen_tokens: bool, + indexes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = 
self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + image_gen_indicators=image_gen_indicators, + exist_non_image_gen_tokens=exist_non_image_gen_tokens, + exist_image_gen_tokens=exist_image_gen_tokens, + indexes=indexes, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + def forward_gen( + self, + hidden_states: torch.Tensor, + image_gen_indicators: torch.Tensor, + exist_non_image_gen_tokens: bool, + exist_image_gen_tokens: bool, + indexes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm_mot_gen(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + image_gen_indicators=image_gen_indicators, + exist_non_image_gen_tokens=exist_non_image_gen_tokens, + exist_image_gen_tokens=exist_image_gen_tokens, + indexes=indexes, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm_mot_gen(hidden_states) + hidden_states = self.mlp_mot_gen(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") + def forward( + self, + hidden_states: torch.Tensor, + image_gen_indicators: torch.Tensor, + exist_non_image_gen_tokens: bool, + exist_image_gen_tokens: bool, + indexes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + if exist_non_image_gen_tokens and not exist_image_gen_tokens: + return self.forward_und(hidden_states, image_gen_indicators, exist_non_image_gen_tokens, exist_image_gen_tokens, indexes, attention_mask, position_ids, past_key_values, use_cache, cache_position, **kwargs) + if not exist_non_image_gen_tokens and exist_image_gen_tokens: + return self.forward_gen(hidden_states, image_gen_indicators, exist_non_image_gen_tokens, exist_image_gen_tokens, indexes, attention_mask, position_ids, past_key_values, use_cache, cache_position, **kwargs) + + residual = hidden_states + + _hidden_states = hidden_states.new_zeros(hidden_states.shape) + if exist_non_image_gen_tokens: + _hidden_states[~image_gen_indicators] = self.input_layernorm(hidden_states[~image_gen_indicators]) + if exist_image_gen_tokens: + _hidden_states[image_gen_indicators] = self.input_layernorm_mot_gen(hidden_states[image_gen_indicators]) + hidden_states = _hidden_states + + 
# Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + image_gen_indicators=image_gen_indicators, + exist_non_image_gen_tokens=exist_non_image_gen_tokens, + exist_image_gen_tokens=exist_image_gen_tokens, + indexes=indexes, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + + _hidden_states = hidden_states.new_zeros(hidden_states.shape) + if exist_non_image_gen_tokens: + _hidden_states[~image_gen_indicators] = self.mlp(self.post_attention_layernorm(hidden_states[~image_gen_indicators])) + + if exist_image_gen_tokens: + _hidden_states[image_gen_indicators] = self.mlp_mot_gen(self.post_attention_layernorm_mot_gen(hidden_states[image_gen_indicators])) + + hidden_states = _hidden_states + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class Qwen3PreTrainedModel(PreTrainedModel): + config: Qwen3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + + _can_compile_fullgraph = True + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": Qwen3DecoderLayer, + "attentions": Qwen3Attention, + } + + +@auto_docstring +class Qwen3Model(Qwen3PreTrainedModel): + def __init__(self, config: Qwen3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm_mot_gen = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + self.current_index = -1 + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + image_gen_indicators: Optional[torch.Tensor] = None, + indexes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + + # assert position_ids is not None + # assert cache_position is not None + # assert past_key_values is not None + + if image_gen_indicators is None: + exist_non_image_gen_tokens = True + exist_image_gen_tokens = False + else: + exist_non_image_gen_tokens = (~image_gen_indicators).any() + exist_image_gen_tokens = image_gen_indicators.any() + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if 
cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + if input_ids is not None: + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + self.current_index += 1 + indexes = torch.LongTensor([[self.current_index], [0], [0]]).to(input_ids.device) + else: + causal_mask_mapping = { + "full_attention": create_block_causal_mask(indexes[0]), + } + self.current_index = indexes[0].max() + else: + self.current_index = indexes[0].max() + # raise NotImplementedError('not isinstance(causal_mask_mapping := attention_mask, dict)') + + # The sliding window alternating layers are not always activated depending on the config + # if self.has_sliding_layers: + # causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs) + + hidden_states = inputs_embeds + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + image_gen_indicators=image_gen_indicators, + exist_non_image_gen_tokens=exist_non_image_gen_tokens, + exist_image_gen_tokens=exist_image_gen_tokens, + indexes=indexes, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + if not exist_image_gen_tokens: + hidden_states = self.norm(hidden_states) + elif not exist_non_image_gen_tokens: + hidden_states = self.norm_mot_gen(hidden_states) + else: + _hidden_states = hidden_states.new_zeros(hidden_states.shape) + _hidden_states[~image_gen_indicators] = self.norm(hidden_states[~image_gen_indicators]) + _hidden_states[image_gen_indicators] = self.norm_mot_gen(hidden_states[image_gen_indicators]) + hidden_states = _hidden_states + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +@auto_docstring +class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = Qwen3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + indexes: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + 
logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3ForCausalLM + + >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + indexes=indexes, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=hidden_states, + attentions=outputs.attentions, + ) + + +class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel): + pass + + +class Qwen3ForTokenClassification(GenericForTokenClassification, Qwen3PreTrainedModel): + pass + + +class Qwen3ForQuestionAnswering(GenericForQuestionAnswering, Qwen3PreTrainedModel): + base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model` + + +__all__ = [ + "Qwen3ForCausalLM", + "Qwen3ForQuestionAnswering", + "Qwen3PreTrainedModel", + "Qwen3Model", + "Qwen3ForSequenceClassification", + "Qwen3ForTokenClassification", +] \ No newline at end of file diff --git a/lightllm/server/x2i_server/past_kv_cache_client.py b/lightllm/server/x2i_server/past_kv_cache_client.py index 4a95635d1e..b06db8da3c 100644 --- a/lightllm/server/x2i_server/past_kv_cache_client.py +++ b/lightllm/server/x2i_server/past_kv_cache_client.py @@ -12,17 +12,11 @@ attach_shm_kv_cache_ptr, register_shm_ptr_to_pin, ) +from lightllm.server.core.objs.x2i_params import PastKVCacheItem logger = init_logger(__name__) -@dataclass -class PastKVCacheItem: - req_id: int - token_len: int - page_indexes: List[int] - - class PastKVCacheClient(object): """ This class is responsible for passing kv cache between generation server and model server, @@ -49,7 +43,7 @@ def __init__(self, only_create_meta_data: bool, init_shm_data: bool): self.attach_shm_handle.wait() return - def allocate_pages(self, req_id: int, need_tokens: int) -> List[int]: + def allocate_pages(self, 
req_id: int, need_tokens: int, img_tokens: int = 0, img_len: int = 0) -> List[int]: need_pages = (need_tokens + self.token_page_size - 1) // self.token_page_size if need_pages > self.page_num: logger.error( @@ -64,7 +58,7 @@ def allocate_pages(self, req_id: int, need_tokens: int) -> List[int]: page_indexes, self.free_pages = self.free_pages[:need_pages], self.free_pages[need_pages:] self.allocated_pages_dict[req_id] = PastKVCacheItem( - req_id=req_id, token_len=need_tokens, page_indexes=page_indexes) + req_id=req_id, token_len=need_tokens, page_indexes=page_indexes, img_tokens=img_tokens, img_len=img_len) return page_indexes @@ -78,7 +72,7 @@ def free_pages_by_req_id(self, req_id: int): def get_pages_by_req_id(self, req_id: int) -> Optional[List[int]]: with self.lock: item = self.allocated_pages_dict.get(req_id, None) - return item.page_indexes if item is not None else None + return item def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optional[torch.Tensor]: if page_indexes is None: From a5a08447345203b7e9253f8ae03123799bd50abf Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Fri, 27 Mar 2026 16:32:18 +0000 Subject: [PATCH 05/41] interleave api. --- lightllm/models/neo_chat_moe/model.py | 16 +- lightllm/server/api_http.py | 12 +- lightllm/server/api_models.py | 100 ++++++- lightllm/server/api_openai.py | 258 ++++++++++++++++-- .../core/objs/token_chunck_hash_list.py | 3 +- lightllm/server/httpserver/manager.py | 5 +- lightllm/server/multimodal_params.py | 3 + lightllm/server/x2i_server/manager.py | 2 +- 8 files changed, 358 insertions(+), 41 deletions(-) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py index edabc06075..ec8772cc72 100644 --- a/lightllm/models/neo_chat_moe/model.py +++ b/lightllm/models/neo_chat_moe/model.py @@ -42,6 +42,7 @@ def __init__(self, tokenizer, model_cfg, **kwargs): self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag) self.image_end_tag = IMG_END_TOKEN self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag) + self.image_tag = IMG_TOKEN self.conversation_module = self.load_conversion_module(tokenizer.name_or_path) self.template = model_cfg.get("template", "neo1_0") @@ -146,16 +147,19 @@ def fix_prompt(self, prompt: str, img_len: int): def get_query_for_it2i(self, prompt: str): image_len = prompt.count(IMG_TOKEN) - query_condition = self._build_t2i_query(prompt, thinking_content="\n\n\n\n") + # query_condition = self._build_t2i_query(prompt, thinking_content="\n\n\n\n") + query_condition = prompt + IMG_START_TOKEN if not prompt.endswith(IMG_START_TOKEN) else prompt query_text_uncondition = self._build_t2i_query(IMG_TOKEN * image_len) question_img_uncondition = self._build_t2i_query("") return query_condition, query_text_uncondition, question_img_uncondition - def get_query_for_t2i(self, prompt): - query_condition = self._build_t2i_query( - f"Please generate an image based on the following description: {prompt}", - thinking_content="\n\n\n\n") - query_uncondition = self._build_t2i_query(f"") + def get_query_for_t2i(self, prompt: str): + # prompt is already applied + query_condition = prompt + IMG_START_TOKEN if not prompt.endswith(IMG_START_TOKEN) else prompt + # query_condition = self._build_t2i_query( + # f"Please generate an image based on the following description: {prompt}", + # thinking_content="\n\n\n\n") + query_uncondition = self._build_t2i_query("") return query_condition, query_uncondition diff --git a/lightllm/server/api_http.py 
b/lightllm/server/api_http.py index 4c366fb770..65571bf0fc 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -51,12 +51,13 @@ from lightllm.utils.envs_utils import get_unique_server_name from dataclasses import dataclass -from .api_openai import chat_completions_impl, completions_impl +from .api_openai import chat_completions_impl, completions_impl, chat_completions_impl_v2 from .api_models import ( ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, CompletionResponse, + ChatCompletionRequestV2, ) from .build_prompt import build_prompt, init_tokenizer @@ -274,6 +275,15 @@ async def generate_image(request: Request) -> Response: except Exception as e: return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) +@app.post("/v2/chat/completions", response_model=ChatCompletionResponse) +async def completions_v2(request: ChatCompletionRequestV2, raw_request: Request) -> Response: + if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: + return create_error_response( + HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" + ) + + resp = await chat_completions_impl_v2(request, raw_request) + return resp @app.get("/tokens") @app.post("/tokens") diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 3d9a6bc8ed..5725799338 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -3,7 +3,7 @@ import uuid from pydantic import BaseModel, Field, field_validator, model_validator -from typing import Any, Dict, List, Optional, Union, Literal, ClassVar +from typing import Any, Dict, List, Optional, Union, Literal, ClassVar, TypeAlias from transformers import GenerationConfig @@ -275,7 +275,7 @@ class UsageInfo(BaseModel): class ChatMessage(BaseModel): role: Optional[str] = None - content: Optional[str] = None + content: Optional[Union[str, List[MessageContent]]] = None reasoning_content: Optional[str] = None tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) @@ -301,7 +301,7 @@ def ensure_id_is_str(cls, v): class DeltaMessage(BaseModel): role: Optional[str] = None - content: Optional[str] = None + content: Optional[Union[str, List[MessageContent]]] = None tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) reasoning_content: Optional[str] = None @@ -370,3 +370,97 @@ class CompletionStreamResponse(BaseModel): @field_validator("id", mode="before") def ensure_id_is_str(cls, v): return str(v) + + +# Supported values +AspectRatio: TypeAlias = Literal[ + "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", + "9:16", "16:9", "21:9", +] + +ImageSize: TypeAlias = Literal["0.5K", "1K", "2K", "4K"] + +Modality: TypeAlias = Literal["text", "image", "audio"] + +ImageType: TypeAlias = Literal["png", "jpeg", "webp"] + +class ImageConfig(BaseModel): + aspect_ratio: AspectRatio = "1:1" + image_size: ImageSize = "1K" + image_type: ImageType = "jpeg" + + # Mapping to actual resolutions (base resolution for 1K) + _aspect_ratio_to_resolution: ClassVar[dict] = { + "1:1": (1024, 1024), + "2:3": (832, 1248), + "3:2": (1248, 832), + "3:4": (864, 1184), + "4:3": (1184, 864), + "4:5": (896, 1152), + "5:4": (1152, 896), + "9:16": (768, 1344), + "16:9": (1344, 768), + "21:9": (1536, 672), + } + + _size_multiplier: ClassVar[dict] = { + "0.5K": 0.5, + "1K": 1.0, + "2K": 2.0, + "4K": 4.0, + } + + @field_validator("aspect_ratio") + @classmethod + def validate_aspect_ratio(cls, v): + if v not in cls._aspect_ratio_to_resolution: + 
raise ValueError(f"Unsupported aspect ratio: {v}") + return v + + @field_validator("image_size") + @classmethod + def validate_image_size(cls, v): + if v not in cls._size_multiplier: + raise ValueError(f"Unsupported image size: {v}") + return v + + @field_validator("image_type") + @classmethod + def validate_image_type(cls, v): + if v not in ['jpeg', 'png', 'webp']: + raise ValueError(f"unsupported image type: {v}") + return v + + def get_resolution(self): + """Return scaled resolution (width, height)""" + base = self._aspect_ratio_to_resolution[self.aspect_ratio] + if base is None: + return None # extended ratios don't have fixed base + + scale = self._size_multiplier[self.image_size] + w, h = base + return int(w * scale), int(h * scale) + + +class ChatCompletionRequestV2(ChatCompletionRequest): + modalities: List[Modality] = ["text"] + image_config: Optional[ImageConfig] = None + + @field_validator("modalities") + @classmethod + def validate_modalities(cls, v): + if "text" not in v: + raise ValueError("modalities must include 'text'") + if len(v) != len(set(v)): + raise ValueError("modalities must be unique") + return v + + @model_validator(mode="after") + def validate_image_config(self): + if "image" in self.modalities: + if self.image_config is None: + self.image_config = ImageConfig() + else: + if self.image_config is not None: + raise ValueError("image_config provided but 'image' not in modalities") + return self \ No newline at end of file diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 598bd8f1f2..e40bfaca7d 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -25,6 +25,7 @@ from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect from fastapi.responses import Response, StreamingResponse, JSONResponse from lightllm.server.core.objs.sampling_params import SamplingParams +from lightllm.server.core.objs.x2i_params import X2IParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster @@ -53,6 +54,9 @@ DeltaMessage, ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice, + ChatCompletionRequestV2, + MessageContent, + ImageURL, ) logger = init_logger(__name__) @@ -159,38 +163,21 @@ def _process_tools_stream(index: int, delta: str, parser_dict: Dict, request: Ch normal_text, calls = parser.parse_stream_chunk(delta) return normal_text, calls - -async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Request) -> Response: - from .api_http import g_objs - - if request.logit_bias is not None: - return create_error_response( - HTTPStatus.BAD_REQUEST, - "The logit_bias parameter is not currently supported", - ) - - if request.function_call != "none": - return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported") - - created_time = int(time.time()) - - multimodal_params_dict = {"images": [], "audios": []} +def _get_images_and_audios(request: ChatCompletionRequest): + images, audios = [], [] for message in request.messages: if isinstance(message.content, list): - texts = [] for content in message.content: - if content.type == "text" and content.text: - texts.append(content.text) - elif content.type == "image_url" and content.image_url is not None: + if content.type == "image_url" and content.image_url is not None: img = content.image_url.url if img.startswith("http://") or img.startswith("https://"): - 
multimodal_params_dict["images"].append({"type": "url", "data": img}) + images.append({"type": "url", "data": img}) elif img.startswith("data:image"): # "data:image/jpeg;base64,{base64_image}" data_str = img.split(";", 1)[1] if data_str.startswith("base64,"): data = data_str[7:] - multimodal_params_dict["images"].append({"type": "base64", "data": data}) + images.append({"type": "base64", "data": data}) else: raise ValueError("Unrecognized image input.") else: @@ -200,18 +187,19 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req elif content.type == "audio_url" and content.audio_url is not None: audio = content.audio_url.url if audio.startswith("http://") or audio.startswith("https://"): - multimodal_params_dict["audios"].append({"type": "url", "data": audio}) + audios.append({"type": "url", "data": audio}) elif audio.startswith("data:audio"): data_str = audio.split(";", 1)[1] if data_str.startswith("base64,"): data = data_str[7:] - multimodal_params_dict["audios"].append({"type": "base64", "data": data}) + audios.append({"type": "base64", "data": data}) else: raise ValueError("Unrecognized audio input.") else: raise ValueError("Unrecognized audio input. Supports local path, http url, base64.") + return images, audios - tools = None +def _get_tools(request: ChatCompletionRequest): if request.tools and request.tool_choice != "none": # request.skip_special_tokens = False if not isinstance(request.tool_choice, str): @@ -223,6 +211,26 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req else: tools = [item.function.model_dump() for item in request.tools] + +async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Request) -> Response: + from .api_http import g_objs + + if request.logit_bias is not None: + return create_error_response( + HTTPStatus.BAD_REQUEST, + "The logit_bias parameter is not currently supported", + ) + + if request.function_call != "none": + return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported") + + created_time = int(time.time()) + + images, audios = _get_images_and_audios(request) + multimodal_params_dict = {"images": images, "audios": audios} + + tools = _get_tools(request) + prompt = await build_prompt(request, tools) sampling_params_dict = { "do_sample": request.do_sample, @@ -526,6 +534,206 @@ async def stream_results() -> AsyncGenerator[bytes, None]: return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) +async def _get_text_generator_input(request: ChatCompletionRequest): + from .api_http import g_objs + + images, audios = _get_images_and_audios(request) + multimodal_params_dict = {"images": images, "audios": audios} + + tools = _get_tools(request) + + prompt = await build_prompt(request, tools) + + sampling_params_dict = { + "do_sample": request.do_sample, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "repetition_penalty": request.repetition_penalty, + "temperature": request.temperature, + "top_p": request.top_p, + "top_k": request.top_k, + "ignore_eos": request.ignore_eos, + "max_new_tokens": request.max_tokens, + "stop_sequences": request.stop, + "n": request.n, + "best_of": request.n, + "add_special_tokens": False, + } + # Structured output handling + if request.response_format: + if request.response_format.type == "json_schema": + obj = request.response_format.json_schema + if obj: + # guided_json takes str instead of dict obj + 
sampling_params_dict["guided_json"] = json.dumps(obj.json_schema) + elif request.response_format.type == "json_object": + sampling_params_dict["guided_grammar"] = "json" + + sampling_params = SamplingParams() + sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sampling_params_dict) + + sampling_params.verify() + multimodal_params = MultimodalParams(**multimodal_params_dict) + + logger.info(f"call text generator with prompt: {prompt} and {sampling_params_dict}") + + return prompt, sampling_params, multimodal_params + + +async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request: Request) -> Response: + from .api_http import g_objs + + if request.logit_bias is not None: + return create_error_response( + HTTPStatus.BAD_REQUEST, + "The logit_bias parameter is not currently supported", + ) + + if request.function_call != "none": + return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported") + + if request.chat_template_kwargs is None: + request.chat_template_kwargs = {} + request.chat_template_kwargs.update({"enable_thinking": False}) + + chat_request: ChatCompletionRequest = ChatCompletionRequest(**request.model_dump()) + + logger.info(f"{type(chat_request)} and {type(request)} and {request.model_dump_json()} and {chat_request.model_dump_json()}") + if "image" not in request.modalities: + return await chat_completions_impl(chat_request, raw_request) + + if not request.stream or request.n != 1: + return create_error_response(HTTPStatus.BAD_REQUEST, "image only support stream api with n = 1") + + image_start_tag = g_objs.httpserver_manager.tokenizer.image_start_tag + image_tag = g_objs.httpserver_manager.tokenizer.image_tag + + stop = chat_request.stop or [] + if isinstance(stop, str): stop = [stop] + chat_request.stop = stop.append(image_start_tag) + + created_time = int(time.time()) + + + prompt, sampling_params, multimodal_params = await _get_text_generator_input(chat_request) + + width, height = request.image_config.get_resolution() + x2i_params_dict = { + "width": width, + "height": height, + } + + x2i_params = X2IParams() + x2i_params.init(**x2i_params_dict) + + async def stream_result() -> AsyncGenerator[bytes, None]: + nonlocal prompt + from .req_id_generator import convert_sub_id_to_group_id + prompt_tokens = 0 + completion_tokens = 0 + finish_reason = None + group_request_id = None + + while True: + need_call_x2i = False + text_generator = g_objs.httpserver_manager.generate( + prompt, sampling_params, multimodal_params.clone(), request=raw_request) + + reasoning_parser_dict = {} + output_chunk = "" + + async for sub_req_id, request_output, metadata, finish_status in text_generator: + prompt_tokens = metadata["prompt_tokens"] + completion_tokens += 1 + if group_request_id is None: + group_request_id = convert_sub_id_to_group_id(sub_req_id) + + index = sub_req_id + delta = request_output + finish_reason = finish_status.get_finish_reason() + if delta == image_start_tag: + need_call_x2i = True + continue + + output_chunk += delta + + # Handle reasoning content + if get_env_start_args().reasoning_parser and request.separate_reasoning: + reasoning_text, delta = _process_reasoning_stream( + index, delta, reasoning_parser_dict, request_output, request + ) + if reasoning_text: + choice_data = ChatCompletionStreamResponseChoice( + index=0, + delta=DeltaMessage(reasoning_content=reasoning_text), + finish_reason=None, + ) + chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + 
choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + + + delta_message = DeltaMessage(role="assistant", content=delta) + if finish_status.is_finished(): + finish_reason = finish_status.get_finish_reason() + stream_choice = ChatCompletionStreamResponseChoice( + index=0, delta=delta_message, finish_reason=finish_reason + ) + stream_resp = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + model=request.model, + choices=[stream_choice], + ) + yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") + + + if need_call_x2i: + prompt += output_chunk + + images = await g_objs.httpserver_manager.generate_image( + prompt, x2i_params, multimodal_params.clone(), request=raw_request) + + delta_message = DeltaMessage(role="assistant", content=[]) + for image in images: + message_content = MessageContent(type="image_url", image_url=ImageURL(url=image)) + delta_message.content.append(message_content) + prompt += image_tag + multimodal_params.add_image({"type": "base64", "data": image}) + + stream_resp = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + model=request.model, + choices=[ChatCompletionStreamResponseChoice(index=0, delta=delta_message)], + ) + yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") + + else: + break + + if request.stream_options and request.stream_options.include_usage: + usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + usage_chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + choices=[], + model=request.model, + usage=usage, + ) + yield f"data: {usage_chunk.model_dump_json()}\n\n" + + return StreamingResponse(stream_result(), media_type="text/event-stream", background=BackgroundTasks()) + + async def completions_impl(request: CompletionRequest, raw_request: Request) -> Response: from .api_http import g_objs diff --git a/lightllm/server/core/objs/token_chunck_hash_list.py b/lightllm/server/core/objs/token_chunck_hash_list.py index b16f264936..de43cc4cc6 100644 --- a/lightllm/server/core/objs/token_chunck_hash_list.py +++ b/lightllm/server/core/objs/token_chunck_hash_list.py @@ -77,8 +77,7 @@ def is_full(self): return self.size == LIGHTLLM_TOKEN_HASH_LIST_SIZE def fill(self, data: List[int]): - assert self.size == 0 - assert len(data) <= LIGHTLLM_TOKEN_HASH_LIST_SIZE + assert len(data) <= LIGHTLLM_TOKEN_HASH_LIST_SIZE, f"data size is too large: {len(data)}" self.items[0 : len(data)] = data self.size = len(data) return diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 9502d630f0..aa590aacea 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -477,7 +477,7 @@ async def generation_wrapper(prompt, sample, multimodal, request): generation_wrapper(prompt_condition, sample_params, multimodal_params, request), generation_wrapper(prompt_uncondition, sample_params, multimodal_params, request)]) generation_params.update_t2i(con_gen, uncon_gen) - # use the first reqeust id as the gen image request id + # use the first request id as the gen image request id x2i_req_id = generate_req_ids[0] generation_params.request_id = x2i_req_id @@ -488,7 +488,6 @@ async def generation_wrapper(prompt, sample, multimodal, request): await self.send_to_x2i.send_pyobj(generation_params, 
protocol=pickle.HIGHEST_PROTOCOL) await req_status.event.wait() - assert req_status.response is not None self.req_id_to_x2i_reqs.pop(x2i_req_id, None) @@ -496,7 +495,7 @@ async def generation_wrapper(prompt, sample, multimodal, request): return req_status.response.images except Exception as e: - logger.error(str(e)) + logger.error(str(e), exc_info=e) return [] finally: diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 999dfc6859..4a3b34e49a 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -170,6 +170,9 @@ async def verify_and_preload(self, request: Request): await audio.preload(request) return + def add_image(self, image: dict): + self.images.append(ImageItem(**image)) + def to_dict(self): ret = {} ret["images"] = [i.to_dict() for i in self.images] diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index 11e53b5d67..828a5ec35e 100644 --- a/lightllm/server/x2i_server/manager.py +++ b/lightllm/server/x2i_server/manager.py @@ -105,7 +105,7 @@ async def loop_for_fwd(self): protocol=pickle.HIGHEST_PROTOCOL) images = [] - logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with x2i_param: {x2i_param}") + logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with: {x2i_param}") if is_t2i: images = await self.t2i_generate(past_kv_cache, past_kv_cache_text, x2i_param) else: From d4aaa77a94ff38f4613de08ecc800a75dcd1264e Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Fri, 27 Mar 2026 17:31:20 +0000 Subject: [PATCH 06/41] nit --- lightllm/server/api_openai.py | 10 ++++++++-- lightllm/server/api_start.py | 2 +- lightllm/server/multimodal_params.py | 3 --- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index e40bfaca7d..07920fad19 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -598,7 +598,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request chat_request: ChatCompletionRequest = ChatCompletionRequest(**request.model_dump()) - logger.info(f"{type(chat_request)} and {type(request)} and {request.model_dump_json()} and {chat_request.model_dump_json()}") + # logger.info(f"{type(chat_request)} and {type(request)} and {request.model_dump_json()} and {chat_request.model_dump_json()}") if "image" not in request.modalities: return await chat_completions_impl(chat_request, raw_request) @@ -633,8 +633,10 @@ async def stream_result() -> AsyncGenerator[bytes, None]: completion_tokens = 0 finish_reason = None group_request_id = None + max_image_gen_num = 15 # TODO: make this configurable - while True: + while max_image_gen_num > 0: + max_image_gen_num -= 1 need_call_x2i = False text_generator = g_objs.httpserver_manager.generate( prompt, sampling_params, multimodal_params.clone(), request=raw_request) @@ -698,6 +700,10 @@ async def stream_result() -> AsyncGenerator[bytes, None]: images = await g_objs.httpserver_manager.generate_image( prompt, x2i_params, multimodal_params.clone(), request=raw_request) + if len(images) == 0: + logger.warning(f"No image generated by x2i: {prompt[-100:]}, exit...") + break + delta_message = DeltaMessage(role="assistant", content=[]) for image in images: message_content = MessageContent(type="image_url", image_url=ImageURL(url=image)) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 358ae425b4..5f43f95d65 100644 --- a/lightllm/server/api_start.py +++ 
b/lightllm/server/api_start.py @@ -250,7 +250,7 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=12+ node_world_size + args.visual_dp * (args.visual_tp + 1), used_nccl_ports=already_uesd_ports + num=12 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 4a3b34e49a..20efb1f053 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -187,8 +187,6 @@ def to_origin_dict(self): ret["images"] = [i.to_origin_dict() for i in self.images] ret["audios"] = [a.to_origin_dict() for a in self.audios] return ret -<<<<<<< HEAD -======= def free(self): for image in self.images: @@ -199,4 +197,3 @@ def free(self): def clone(self): return MultimodalParams(**self.to_origin_dict()) ->>>>>>> 43a488a8 (add naive x2i backend.) From 99deab60bf051699d2db2f1f6a8010611ae42c11 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Fri, 27 Mar 2026 18:31:19 +0000 Subject: [PATCH 07/41] fix. --- lightllm/server/api_start.py | 28 +++++++++---------- .../model_infer/mode_backend/base_backend.py | 1 - lightllm/server/tokenizer.py | 3 ++ 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 5f43f95d65..5ce80bf419 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -352,20 +352,20 @@ def normal_or_p_d_start(args): ], ) - if args.enable_multimodal_x2i: - from .x2i_server.manager import start_x2i_process, setup_devices - origin_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) - setup_devices(args) - process_manager.start_submodule_processes( - start_funcs=[ - start_x2i_process, - ], - start_args=[ - (args,), - ], - ) - if origin_devices: - os.environ["CUDA_VISIBLE_DEVICES"] = origin_devices + if args.enable_multimodal_x2i: + from .x2i_server.manager import start_x2i_process, setup_devices + origin_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + setup_devices(args) + process_manager.start_submodule_processes( + start_funcs=[ + start_x2i_process, + ], + start_args=[ + (args,), + ], + ) + if origin_devices: + os.environ["CUDA_VISIBLE_DEVICES"] = origin_devices if args.enable_cpu_cache: from .multi_level_kv_cache.manager import start_multi_level_kv_cache_manager diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 12e418824b..0cc492de20 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -49,7 +49,6 @@ from .multi_level_kv_cache import MultiLevelKvCacheModule from .past_kv_cache import PastKVCacheModule - class ModeBackend: def __init__(self) -> None: self.shm_req_manager = ShmReqManager() diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 2800bf0f6b..87904aa97d 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -32,6 +32,7 @@ from ..models.internvl.model import InternvlTokenizer from ..models.gemma3.model import Gemma3Tokenizer from ..models.qwen3_omni_moe_thinker.model import QWen3OmniTokenizer +from ..models.neo_chat_moe.model import NeoChatTokenizer # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. 
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" @@ -122,5 +123,7 @@ def get_tokenizer( tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) + elif model_type == "neo_chat": + tokenizer = NeoChatTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) return tokenizer From f4fccf25bddd471660aa5cc6be5ac2110fb67b4a Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Fri, 3 Apr 2026 08:39:42 +0000 Subject: [PATCH 08/41] add lightx2v & openrouter api --- lightllm/models/neo_chat_moe/infer_struct.py | 2 - lightllm/server/api_cli.py | 8 +- lightllm/server/api_http.py | 22 +- lightllm/server/api_models.py | 29 ++- lightllm/server/api_openai.py | 206 +++++++++++++++--- lightllm/server/core/objs/sampling_params.py | 6 + lightllm/server/core/objs/x2i_params.py | 33 ++- lightllm/server/tokenizer.py | 2 +- lightllm/server/x2i_server/manager.py | 91 ++++---- .../server/x2i_server/past_kv_cache_client.py | 26 ++- lightllm/utils/config_utils.py | 2 + 11 files changed, 326 insertions(+), 101 deletions(-) diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py index add8abda08..1693bcb964 100644 --- a/lightllm/models/neo_chat_moe/infer_struct.py +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -66,8 +66,6 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: images = p.get("images", []) for img in images: b_image_start_idx.append(img["start_idx"]) - # a = img["start_idx"] - # print(f"img start_idx: {a}") b_image_len.append(img["token_num"]) b_image_thwd.append(img["grid_thwd"]) b_image_nums.append(len(images)) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 4c4be78b3c..96697fae22 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -305,7 +305,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--enable_multimodal_x2i", action="store_true", - help="Whether or not to allow to generate images (requird --enable_multimodal)." + help="Whether or not to allow to generate images (requird --enable_multimodal).", ) parser.add_argument( "--x2i_server_used_gpus", @@ -313,6 +313,12 @@ def make_argument_parser() -> argparse.ArgumentParser: default=1, help="Number of GPUs to use for x2i server (requird --enable_multimodal_x2i).", ) + parser.add_argument( + "--x2v_gen_model_config", + type=str, + default=None, + help="Path of the x2v config file.", + ) parser.add_argument( "--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service." 
) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 65571bf0fc..4f7e3c18bd 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -89,6 +89,7 @@ def set_args(self, args: StartArgs): if args.enable_multimodal_x2i: from .api_lightllm import lightllm_generate_image + self.g_generate_image_func = lightllm_generate_image setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::api_server") @@ -242,15 +243,15 @@ async def compat_generate(request: Request) -> Response: return await generate(request) -@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) -async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response: - if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: - return create_error_response( - HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" - ) +# @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +# async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response: +# if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: +# return create_error_response( +# HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" +# ) - resp = await chat_completions_impl(request, raw_request) - return resp +# resp = await chat_completions_impl(request, raw_request) +# return resp @app.post("/v1/completions", response_model=CompletionResponse) @@ -263,6 +264,7 @@ async def completions(request: CompletionRequest, raw_request: Request) -> Respo resp = await completions_impl(request, raw_request) return resp + @app.post("/generate_image") async def generate_image(request: Request) -> Response: if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: @@ -275,7 +277,8 @@ async def generate_image(request: Request) -> Response: except Exception as e: return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) -@app.post("/v2/chat/completions", response_model=ChatCompletionResponse) + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def completions_v2(request: ChatCompletionRequestV2, raw_request: Request) -> Response: if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: return create_error_response( @@ -285,6 +288,7 @@ async def completions_v2(request: ChatCompletionRequestV2, raw_request: Request) resp = await chat_completions_impl_v2(request, raw_request) return resp + @app.get("/tokens") @app.post("/tokens") async def tokens(request: Request): diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 5725799338..737e1b5b3c 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -278,6 +278,8 @@ class ChatMessage(BaseModel): content: Optional[Union[str, List[MessageContent]]] = None reasoning_content: Optional[str] = None tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) + # OpenRouter-style: generated images alongside text; content may include "" placeholders + images: Optional[List[MessageContent]] = None class ChatCompletionResponseChoice(BaseModel): @@ -304,6 +306,7 @@ class DeltaMessage(BaseModel): content: Optional[Union[str, List[MessageContent]]] = None tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) reasoning_content: Optional[str] = None + images: Optional[List[MessageContent]] = None class 
ChatCompletionStreamResponseChoice(BaseModel): @@ -374,20 +377,36 @@ def ensure_id_is_str(cls, v): # Supported values AspectRatio: TypeAlias = Literal[ - "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", - "9:16", "16:9", "21:9", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "4:5", + "5:4", + "9:16", + "16:9", + "21:9", ] -ImageSize: TypeAlias = Literal["0.5K", "1K", "2K", "4K"] +ImageSize: TypeAlias = Literal["0.5K", "1K", "2K", "4K"] Modality: TypeAlias = Literal["text", "image", "audio"] ImageType: TypeAlias = Literal["png", "jpeg", "webp"] + class ImageConfig(BaseModel): aspect_ratio: AspectRatio = "1:1" image_size: ImageSize = "1K" image_type: ImageType = "jpeg" + # X2I / diffusion sampling (optional; server defaults apply when omitted) + steps: Optional[int] = None + guidance_scale: Optional[float] = None + image_guidance_scale: Optional[float] = None + seed: Optional[int] = None + num_images: Optional[int] = None + cfg_norm: Optional[Literal["none", "cfg_zero_star", "global", "text_channel", "channel"]] = None # Mapping to actual resolutions (base resolution for 1K) _aspect_ratio_to_resolution: ClassVar[dict] = { @@ -427,7 +446,7 @@ def validate_image_size(cls, v): @field_validator("image_type") @classmethod def validate_image_type(cls, v): - if v not in ['jpeg', 'png', 'webp']: + if v not in ["jpeg", "png", "webp"]: raise ValueError(f"unsupported image type: {v}") return v @@ -463,4 +482,4 @@ def validate_image_config(self): else: if self.image_config is not None: raise ValueError("image_config provided but 'image' not in modalities") - return self \ No newline at end of file + return self diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 07920fad19..9ea4b13bc4 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -163,6 +163,7 @@ def _process_tools_stream(index: int, delta: str, parser_dict: Dict, request: Ch normal_text, calls = parser.parse_stream_chunk(delta) return normal_text, calls + def _get_images_and_audios(request: ChatCompletionRequest): images, audios = [], [] for message in request.messages: @@ -199,6 +200,7 @@ def _get_images_and_audios(request: ChatCompletionRequest): raise ValueError("Unrecognized audio input. 
Supports local path, http url, base64.") return images, audios + def _get_tools(request: ChatCompletionRequest): if request.tools and request.tool_choice != "none": # request.skip_special_tokens = False @@ -210,6 +212,7 @@ def _get_tools(request: ChatCompletionRequest): ] else: tools = [item.function.model_dump() for item in request.tools] + return tools async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Request) -> Response: @@ -559,7 +562,7 @@ async def _get_text_generator_input(request: ChatCompletionRequest): "best_of": request.n, "add_special_tokens": False, } - # Structured output handling + # Structured output handling if request.response_format: if request.response_format.type == "json_schema": obj = request.response_format.json_schema @@ -580,6 +583,43 @@ async def _get_text_generator_input(request: ChatCompletionRequest): return prompt, sampling_params, multimodal_params +def _raw_image_to_data_url(image: Union[str, bytes], image_type: str) -> str: + mime = {"jpeg": "image/jpeg", "png": "image/png", "webp": "image/webp"}[image_type] + if isinstance(image, bytes): + b64 = base64.b64encode(image).decode("ascii") + else: + if image.startswith("data:"): + return image + b64 = image + return f"data:{mime};base64,{b64}" + + +def _message_contents_from_raw_images(images: List[Any], image_type: str) -> List[MessageContent]: + return [ + MessageContent( + type="image_url", + image_url=ImageURL(url=_raw_image_to_data_url(img, image_type)), + ) + for img in images + ] + + +def _normalize_image_b64_for_multimodal(image: Union[str, bytes]) -> str: + if isinstance(image, bytes): + return base64.b64encode(image).decode("ascii") + return image + + +def _apply_image_generation_stop(chat_request: ChatCompletionRequest, image_start_tag: str) -> None: + stop = chat_request.stop or [] + if isinstance(stop, str): + stop = [stop] + stop = list(stop) + if image_start_tag not in stop: + stop.append(image_start_tag) + chat_request.stop = stop + + async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request: Request) -> Response: from .api_http import g_objs @@ -598,48 +638,135 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request chat_request: ChatCompletionRequest = ChatCompletionRequest(**request.model_dump()) - # logger.info(f"{type(chat_request)} and {type(request)} and {request.model_dump_json()} and {chat_request.model_dump_json()}") if "image" not in request.modalities: return await chat_completions_impl(chat_request, raw_request) - if not request.stream or request.n != 1: - return create_error_response(HTTPStatus.BAD_REQUEST, "image only support stream api with n = 1") + if request.n != 1: + return create_error_response( + HTTPStatus.BAD_REQUEST, + "multimodal image generation only supports n = 1", + ) image_start_tag = g_objs.httpserver_manager.tokenizer.image_start_tag image_tag = g_objs.httpserver_manager.tokenizer.image_tag - stop = chat_request.stop or [] - if isinstance(stop, str): stop = [stop] - chat_request.stop = stop.append(image_start_tag) + _apply_image_generation_stop(chat_request, image_start_tag) created_time = int(time.time()) - prompt, sampling_params, multimodal_params = await _get_text_generator_input(chat_request) - width, height = request.image_config.get_resolution() - x2i_params_dict = { - "width": width, - "height": height, - } - x2i_params = X2IParams() - x2i_params.init(**x2i_params_dict) + x2i_params.init_from_image_config(request.image_config) + + if not request.stream: + from .req_id_generator 
import convert_sub_id_to_group_id + + full_text = "" + response_images: List[MessageContent] = [] + prompt_tokens = 0 + completion_tokens = 0 + finish_reason: Optional[str] = "stop" + group_request_id = None + max_image_gen_num = 15 # TODO: make this configurable + + while max_image_gen_num > 0: + max_image_gen_num -= 1 + need_call_x2i = False + output_chunk = "" + text_generator = g_objs.httpserver_manager.generate( + prompt, sampling_params, multimodal_params.clone(), request=raw_request + ) + async for sub_req_id, request_output, metadata, finish_status in text_generator: + prompt_tokens = metadata["prompt_tokens"] + completion_tokens += 1 + if group_request_id is None: + group_request_id = convert_sub_id_to_group_id(sub_req_id) + delta = request_output + if finish_status.is_finished(): + finish_reason = finish_status.get_finish_reason() + if delta == image_start_tag: + need_call_x2i = True + continue + output_chunk += delta + + full_text += output_chunk + + if need_call_x2i: + prompt += output_chunk + images = await g_objs.httpserver_manager.generate_image( + prompt, x2i_params, multimodal_params.clone(), request=raw_request + ) + if len(images) == 0: + logger.warning(f"No image generated by x2i: {prompt[-100:]}, exit...") + break + response_images.extend(_message_contents_from_raw_images(images, request.image_config.image_type)) + for image in images: + prompt += image_tag + full_text += image_tag + multimodal_params.add_image({"type": "base64", "data": _normalize_image_b64_for_multimodal(image)}) + else: + break + + reasoning_text = None + text_out = full_text + reasoning_parser = get_env_start_args().reasoning_parser + if reasoning_parser and request.separate_reasoning: + request_enable_reasoning = _get_reasoning_from_request(request) + try: + parser = ReasoningParser( + model_type=reasoning_parser, + stream_reasoning=False, + force_reasoning=request_enable_reasoning, + ) + reasoning_text, text_out = parser.parse_non_stream(full_text) + except Exception as e: + logger.error(f"Reasoning parsing error: {e}") + return create_error_response( + HTTPStatus.BAD_REQUEST, + "Failed to parse reasoning content!", + ) + + usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + chat_message = ChatMessage( + role="assistant", + content=text_out, + reasoning_content=reasoning_text if reasoning_text else "", + images=response_images if response_images else None, + ) + choice = ChatCompletionResponseChoice( + index=0, + message=chat_message, + finish_reason=finish_reason, + ) + return ChatCompletionResponse( + id=group_request_id or f"chatcmpl-{uuid.uuid4().hex}", + created=created_time, + model=request.model, + choices=[choice], + usage=usage, + ) async def stream_result() -> AsyncGenerator[bytes, None]: nonlocal prompt from .req_id_generator import convert_sub_id_to_group_id + prompt_tokens = 0 completion_tokens = 0 finish_reason = None group_request_id = None - max_image_gen_num = 15 # TODO: make this configurable + max_image_gen_num = 15 # TODO: make this configurable while max_image_gen_num > 0: max_image_gen_num -= 1 need_call_x2i = False text_generator = g_objs.httpserver_manager.generate( - prompt, sampling_params, multimodal_params.clone(), request=raw_request) + prompt, sampling_params, multimodal_params.clone(), request=raw_request + ) reasoning_parser_dict = {} output_chunk = "" @@ -659,7 +786,7 @@ async def stream_result() -> AsyncGenerator[bytes, None]: output_chunk += delta - # Handle reasoning 
content + # Handle reasoning content if get_env_start_args().reasoning_parser and request.separate_reasoning: reasoning_text, delta = _process_reasoning_stream( index, delta, reasoning_parser_dict, request_output, request @@ -678,7 +805,6 @@ async def stream_result() -> AsyncGenerator[bytes, None]: ) yield f"data: {chunk.model_dump_json()}\n\n" - delta_message = DeltaMessage(role="assistant", content=delta) if finish_status.is_finished(): finish_reason = finish_status.get_finish_reason() @@ -691,33 +817,51 @@ async def stream_result() -> AsyncGenerator[bytes, None]: model=request.model, choices=[stream_choice], ) - yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") - + yield ("data: " + json.dumps(stream_resp.model_dump(), ensure_ascii=False) + "\n\n").encode("utf-8") if need_call_x2i: prompt += output_chunk images = await g_objs.httpserver_manager.generate_image( - prompt, x2i_params, multimodal_params.clone(), request=raw_request) + prompt, x2i_params, multimodal_params.clone(), request=raw_request + ) if len(images) == 0: logger.warning(f"No image generated by x2i: {prompt[-100:]}, exit...") break - delta_message = DeltaMessage(role="assistant", content=[]) - for image in images: - message_content = MessageContent(type="image_url", image_url=ImageURL(url=image)) - delta_message.content.append(message_content) - prompt += image_tag - multimodal_params.add_image({"type": "base64", "data": image}) + tag_chunk = ChatCompletionStreamResponse( + id=group_request_id, + created=created_time, + model=request.model, + choices=[ + ChatCompletionStreamResponseChoice( + index=0, + delta=DeltaMessage(role="assistant", content=image_tag), + finish_reason=None, + ) + ], + ) + yield f"data: {tag_chunk.model_dump_json()}\n\n" - stream_resp = ChatCompletionStreamResponse( + img_items = _message_contents_from_raw_images(images, request.image_config.image_type) + img_chunk = ChatCompletionStreamResponse( id=group_request_id, created=created_time, model=request.model, - choices=[ChatCompletionStreamResponseChoice(index=0, delta=delta_message)], + choices=[ + ChatCompletionStreamResponseChoice( + index=0, + delta=DeltaMessage(role="assistant", images=img_items), + finish_reason=None, + ) + ], ) - yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8") + yield f"data: {img_chunk.model_dump_json()}\n\n" + + for image in images: + prompt += image_tag + multimodal_params.add_image({"type": "base64", "data": _normalize_image_b64_for_multimodal(image)}) else: break diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 3cf5cb0887..85112a0379 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -293,6 +293,8 @@ class SamplingParams(ctypes.Structure): ("ignore_eos", ctypes.c_bool), # the max number of image patches to be used in the internvl model, for the test ("image_max_patch_num", ctypes.c_int), + ("min_pixels", ctypes.c_int), + ("max_pixels", ctypes.c_int), ("max_new_tokens", ctypes.c_int), ("min_new_tokens", ctypes.c_int), # Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty @@ -346,6 +348,8 @@ def init(self, tokenizer, **kwargs): self.top_k = kwargs.get("top_k", SamplingParams._top_k) self.ignore_eos = kwargs.get("ignore_eos", False) self.image_max_patch_num = kwargs.get("image_max_patch_num", -1) + self.min_pixels = kwargs.get("min_pixels", -1) + self.max_pixels = 
kwargs.get("max_pixels", -1) self.max_new_tokens = kwargs.get("max_new_tokens", 16384) self.min_new_tokens = kwargs.get("min_new_tokens", 1) self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY) @@ -487,6 +491,8 @@ def to_dict(self): "top_k": self.top_k, "ignore_eos": self.ignore_eos, "image_max_patch_num": self.image_max_patch_num, + "min_pixels": self.min_pixels, + "max_pixels": self.max_pixels, "max_new_tokens": self.max_new_tokens, "min_new_tokens": self.min_new_tokens, "exponential_decay_length_penalty": self.exponential_decay_length_penalty.to_tuple(), diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py index 50b2595ec2..9a29ef949d 100644 --- a/lightllm/server/core/objs/x2i_params.py +++ b/lightllm/server/core/objs/x2i_params.py @@ -1,9 +1,10 @@ import ctypes from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from enum import IntEnum from .token_chunck_hash_list import PastKVCachePageList + class CfgNormType(IntEnum): NONE = 0 CFG_ZERO_STAR = 1 @@ -52,10 +53,11 @@ class X2IParams(ctypes.Structure): _num_images: int = 1 _cfg_norm: CfgNormType = CfgNormType.NONE - def init(self, **kwargs): + def init(self, **kwargs): def _get(key, default): v = kwargs.get(key) return v if v is not None else default + self.width = _get("width", X2IParams._width) self.height = _get("height", X2IParams._height) self.steps = _get("steps", X2IParams._steps) @@ -70,6 +72,31 @@ def _get(key, default): self.total_prompt_tokens = 0 self.request_id = 0 + def init_from_image_config(self, image_config: Any) -> None: + """从 HTTP `image_config`(api_models.ImageConfig)填充,与 `init(**kwargs)` 共用默认值逻辑。""" + from lightllm.server.api_models import ImageConfig + + if not isinstance(image_config, ImageConfig): + raise TypeError(f"expected ImageConfig, got {type(image_config)!r}") + w, h = image_config.get_resolution() + kwargs: Dict[str, Any] = {"width": w, "height": h} + if image_config.steps is not None: + kwargs["steps"] = image_config.steps + if image_config.guidance_scale is not None: + kwargs["guidance_scale"] = image_config.guidance_scale + if image_config.image_guidance_scale is not None: + kwargs["image_guidance_scale"] = image_config.image_guidance_scale + if image_config.seed is not None: + kwargs["seed"] = image_config.seed + if image_config.num_images is not None: + kwargs["num_images"] = image_config.num_images + if image_config.cfg_norm is not None: + for e in CfgNormType: + if e.as_str() == image_config.cfg_norm: + kwargs["cfg_norm"] = e + break + self.init(**kwargs) + def update(self, past_kv: PastKVCachePageList, meta: Dict): item: PastKVCacheItem = meta.get("kv_cache_item") past_kv.token_len = item.token_len @@ -119,4 +146,4 @@ class PastKVCacheItem: token_len: int img_tokens: int img_len: int - page_indexes: List[int] \ No newline at end of file + page_indexes: List[int] diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 87904aa97d..18f7778bbb 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -125,5 +125,5 @@ def get_tokenizer( tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) elif model_type == "neo_chat": tokenizer = NeoChatTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) - + print(type(tokenizer), tokenizer, flush=True) return tokenizer diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index 828a5ec35e..5fe817f175 100644 --- a/lightllm/server/x2i_server/manager.py 
+++ b/lightllm/server/x2i_server/manager.py @@ -22,14 +22,15 @@ logger = init_logger(__name__) -''' +""" manage a generation service, 1. start x2v pipelines 2. receive generation request from http_server. 3. call llm gen to obtain past key values 4. call x2v to generate images and pass the key values to it 5. return the generated images. -''' +""" + class X2IManager: def __init__( @@ -50,30 +51,43 @@ def __init__( self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True) async def wait_to_model_ready(self): - # from lightx2v import LightX2VPipeline - # self.gen_pipe = LightX2VPipeline( - # model_path = self.args.model_dir, - # model_cls = self.args.model_name, - # task="t2i" - # ) - # self.gen_pipe.create_generator( - # config_json = self.args.x2v_gen_model_config, - # ) - - from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I - - self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device()) - + from lightx2v import LightX2VPipeline + + self.gen_pipe = LightX2VPipeline( + model_path=self.args.model_dir, + model_cls="neopp", + support_tasks=["t2i", "i2i"], + ) + self.gen_pipe.create_generator( + config_json=self.args.x2v_gen_model_config, + ) + self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False}) + # from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I + # self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device()) pass async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams): - images = self.naive_x2i.t2i(past_kv_cache, past_kv_cache_text, param) - return images + print(past_kv_cache.shape, past_kv_cache_text.shape, param, flush=True) + self.gen_pipe.runner.set_kvcache_t2i(past_kv_cache, past_kv_cache_text) + image = self.gen_pipe.generate( + seed=param.seed, + task="t2i", + save_result_path="", # 返回base64,不需要指定路径了 + target_shape=[param.height, param.width], # Height, Width + ) + # images = self.naive_x2i.t2i(past_kv_cache, past_kv_cache_text, param) + return [image] async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams): - images = self.naive_x2i.it2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img, param) - return images - + self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img) + image = self.gen_pipe.generate( + seed=param.seed, + task="i2i", + save_result_path="", # 返回base64,不需要指定路径了 + target_shape=[param.height, param.width], # Height, Width + ) + # images = self.naive_x2i.it2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img, param) + return [image] async def loop_for_fwd(self): while True: @@ -94,15 +108,15 @@ async def loop_for_fwd(self): is_t2i = x2i_param.past_kvcache_img.is_empty() past_kv_cache_img = None - if not is_t2i: # t2i + if not is_t2i: # t2i past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i( x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len ) # release self.send_to_httpserver.send_pyobj( - X2ICacheRelease(request_id=x2i_param.request_id), - protocol=pickle.HIGHEST_PROTOCOL) + X2ICacheRelease(request_id=x2i_param.request_id), protocol=pickle.HIGHEST_PROTOCOL + ) images = [] logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with: {x2i_param}") @@ -111,20 +125,17 @@ async def loop_for_fwd(self): else: images = await self.it2i_generate(past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param) - self.send_to_httpserver.send_pyobj(X2IResponse( - request_id=x2i_param.request_id, - 
images=images), - protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_httpserver.send_pyobj( + X2IResponse(request_id=x2i_param.request_id, images=images), protocol=pickle.HIGHEST_PROTOCOL + ) except Exception as e: - self.send_to_httpserver.send_pyobj(X2IResponse( - request_id=x2i_param.request_id, - images=None), - protocol=pickle.HIGHEST_PROTOCOL) + self.send_to_httpserver.send_pyobj( + X2IResponse(request_id=x2i_param.request_id, images=None), protocol=pickle.HIGHEST_PROTOCOL + ) logger.error(e) - async def loop_for_netio_req(self): while True: try: @@ -139,6 +150,7 @@ async def loop_for_netio_req(self): def clean_up(self): pass + def setup_devices(args: StartArgs): devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip() logger.info(f"current devices: {devices} {torch.cuda.device_count()}") @@ -152,11 +164,12 @@ def setup_devices(args: StartArgs): if len(devices) < llm_need_gpus + x2i_need_gpus: raise ValueError(f"devices {devices} not enough, need {llm_need_gpus} and {x2i_need_gpus}") - os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices[ - llm_need_gpus:llm_need_gpus + x2i_need_gpus])) + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices[llm_need_gpus : llm_need_gpus + x2i_need_gpus])) - logger.info(f"setup devices for x2i server: {os.environ['CUDA_VISIBLE_DEVICES']}, " - f"{torch.cuda.device_count()} {torch.cuda.current_device()}") + logger.info( + f"setup devices for x2i server: {os.environ['CUDA_VISIBLE_DEVICES']}, " + f"{torch.cuda.device_count()} {torch.cuda.current_device()}" + ) def start_x2i_process(args, pipe_writer): @@ -166,7 +179,9 @@ def start_x2i_process(args, pipe_writer): start_parent_check_thread() set_current_device_id(torch.cuda.current_device()) try: - x2iserver = X2IManager(args=args,) + x2iserver = X2IManager( + args=args, + ) asyncio.run(x2iserver.wait_to_model_ready()) except Exception as e: logger.exception(str(e)) diff --git a/lightllm/server/x2i_server/past_kv_cache_client.py b/lightllm/server/x2i_server/past_kv_cache_client.py index b06db8da3c..56781f49ab 100644 --- a/lightllm/server/x2i_server/past_kv_cache_client.py +++ b/lightllm/server/x2i_server/past_kv_cache_client.py @@ -58,7 +58,8 @@ def allocate_pages(self, req_id: int, need_tokens: int, img_tokens: int = 0, img page_indexes, self.free_pages = self.free_pages[:need_pages], self.free_pages[need_pages:] self.allocated_pages_dict[req_id] = PastKVCacheItem( - req_id=req_id, token_len=need_tokens, page_indexes=page_indexes, img_tokens=img_tokens, img_len=img_len) + req_id=req_id, token_len=need_tokens, page_indexes=page_indexes, img_tokens=img_tokens, img_len=img_len + ) return page_indexes @@ -77,17 +78,20 @@ def get_pages_by_req_id(self, req_id: int) -> Optional[List[int]]: def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optional[torch.Tensor]: if page_indexes is None: return None - assert token_num <= len(page_indexes) * self.token_page_size and \ - token_num > (len(page_indexes) - 1) * self.token_page_size + assert ( + token_num <= len(page_indexes) * self.token_page_size + and token_num > (len(page_indexes) - 1) * self.token_page_size + ) (P, L, S, H, D) = self.cpu_kv_cache_tensor[page_indexes].shape - # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (2, L, H // 2, P, S, D) - # -> (2, L, H // 2, P * S, D) -> ( L, 2, H // 2, P * S, D) - kv = self.cpu_kv_cache_tensor[page_indexes] \ - .view(P, L, S, 2, H // 2, D) \ - .permute(3, 1, 4, 0, 2, 5).contiguous() \ - .view(2, L, H // 2, P * S, D) \ - .permute(1, 0, 2, 3, 4) - return kv[:, :, :, :token_num, 
:].contiguous() + # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (L, 2, H // 2, P, S, D) -> (L, 2, H // 2, P * S, D) + kv = ( + self.cpu_kv_cache_tensor[page_indexes] + .view(P, L, S, 2, H // 2, D) + .permute(1, 3, 0, 2, 4, 5) + .contiguous() + .view(L, 2, P * S, H // 2, D) + ) + return kv def _create_shm_cpu_kv_cache(self): shm_ptr = create_shm_kv_cache_ptr( diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index 54f20384b7..a9088d9a2e 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -186,6 +186,8 @@ def has_vision_module(model_path: str) -> bool: ): # Qwen3OmniMoeVisionTransformerPretrainedModel return True + elif model_type == "neo_chat": + return True else: raise Exception("unknown vision model type") except: From 06571d0e78afb3b3fc3915587f5275dc625903ee Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Tue, 7 Apr 2026 06:21:13 +0000 Subject: [PATCH 09/41] fix x2v && add openai image-only --- lightllm/models/neo_chat_moe/model.py | 6 +++ lightllm/server/api_models.py | 4 +- lightllm/server/api_openai.py | 48 +++++++++++++++++-- lightllm/server/x2i_server/manager.py | 21 +++++++- .../server/x2i_server/past_kv_cache_client.py | 2 +- 5 files changed, 74 insertions(+), 7 deletions(-) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py index ec8772cc72..900d2eb028 100644 --- a/lightllm/models/neo_chat_moe/model.py +++ b/lightllm/models/neo_chat_moe/model.py @@ -48,6 +48,7 @@ def __init__(self, tokenizer, model_cfg, **kwargs): def load_conversion_module(self, model_dir: str): import importlib + conversion_path = os.path.join(model_dir, "conversation.py") if not os.path.exists(conversion_path): return None @@ -151,6 +152,9 @@ def get_query_for_it2i(self, prompt: str): query_condition = prompt + IMG_START_TOKEN if not prompt.endswith(IMG_START_TOKEN) else prompt query_text_uncondition = self._build_t2i_query(IMG_TOKEN * image_len) question_img_uncondition = self._build_t2i_query("") + print(f"query_condition: {query_condition}") + print(f"query_text_uncondition: {query_text_uncondition}") + print(f"question_img_uncondition: {question_img_uncondition}") return query_condition, query_text_uncondition, question_img_uncondition def get_query_for_t2i(self, prompt: str): @@ -160,6 +164,8 @@ def get_query_for_t2i(self, prompt: str): # f"Please generate an image based on the following description: {prompt}", # thinking_content="\n\n\n\n") query_uncondition = self._build_t2i_query("") + print(f"query_condition: {query_condition}", flush=True) + print(f"query_uncondition: {query_uncondition}", flush=True) return query_condition, query_uncondition diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 737e1b5b3c..643bba66e3 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -468,8 +468,8 @@ class ChatCompletionRequestV2(ChatCompletionRequest): @field_validator("modalities") @classmethod def validate_modalities(cls, v): - if "text" not in v: - raise ValueError("modalities must include 'text'") + if "text" not in v and v != ["image"]: + raise ValueError("modalities must include 'text', or be ['image'] for image-only generation") if len(v) != len(set(v)): raise ValueError("modalities must be unique") return v diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 9ea4b13bc4..0f364cdc1e 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -202,6 +202,7 @@ def 
_get_images_and_audios(request: ChatCompletionRequest): def _get_tools(request: ChatCompletionRequest): + tools = None if request.tools and request.tool_choice != "none": # request.skip_special_tokens = False if not isinstance(request.tool_choice, str): @@ -537,15 +538,17 @@ async def stream_results() -> AsyncGenerator[bytes, None]: return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) -async def _get_text_generator_input(request: ChatCompletionRequest): +async def _get_text_generator_input(request: ChatCompletionRequest, apply_chat_template: bool = True): from .api_http import g_objs images, audios = _get_images_and_audios(request) multimodal_params_dict = {"images": images, "audios": audios} tools = _get_tools(request) - - prompt = await build_prompt(request, tools) + if apply_chat_template: + prompt = await build_prompt(request, tools) + else: + prompt = request.messages[-1].content sampling_params_dict = { "do_sample": request.do_sample, @@ -620,6 +623,42 @@ def _apply_image_generation_stop(chat_request: ChatCompletionRequest, image_star chat_request.stop = stop +async def _chat_completion_image_only( + request: ChatCompletionRequestV2, + raw_request: Request, +) -> ChatCompletionResponse: + from .api_http import g_objs + + created_time = int(time.time()) + x2i_params = X2IParams() + x2i_params.init_from_image_config(request.image_config) + + prompt, _, multimodal_params = await _get_text_generator_input(request, apply_chat_template=False) + + images = await g_objs.httpserver_manager.generate_image( + prompt, x2i_params, multimodal_params.clone(), request=raw_request + ) + response_images = _message_contents_from_raw_images(images, request.image_config.image_type) + chat_message = ChatMessage( + role="assistant", + content=prompt, + images=response_images if response_images else None, + ) + choice = ChatCompletionResponseChoice( + index=0, + message=chat_message, + finish_reason="stop", + ) + usage = UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0) + return ChatCompletionResponse( + id=f"chatcmpl-{uuid.uuid4().hex}", + created=created_time, + model=request.model, + choices=[choice], + usage=usage, + ) + + async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request: Request) -> Response: from .api_http import g_objs @@ -641,6 +680,9 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request if "image" not in request.modalities: return await chat_completions_impl(chat_request, raw_request) + if request.modalities == ["image"]: + return await _chat_completion_image_only(request, raw_request) + if request.n != 1: return create_error_response( HTTPStatus.BAD_REQUEST, diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index 5fe817f175..e9e5e6de4b 100644 --- a/lightllm/server/x2i_server/manager.py +++ b/lightllm/server/x2i_server/manager.py @@ -67,7 +67,12 @@ async def wait_to_model_ready(self): pass async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams): - print(past_kv_cache.shape, past_kv_cache_text.shape, param, flush=True) + past_kv_cache = self._truncate_kv_cache_to_compressed_len( + past_kv_cache, param.past_kvcache.get_compressed_len() + ) + past_kv_cache_text = self._truncate_kv_cache_to_compressed_len( + past_kv_cache_text, param.past_kvcache_text.get_compressed_len() + ) self.gen_pipe.runner.set_kvcache_t2i(past_kv_cache, past_kv_cache_text) image = self.gen_pipe.generate( seed=param.seed, @@ -79,6 +84,15 
@@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams return [image] async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams): + past_kv_cache = self._truncate_kv_cache_to_compressed_len( + past_kv_cache, param.past_kvcache.get_compressed_len() + ) + past_kv_cache_text = self._truncate_kv_cache_to_compressed_len( + past_kv_cache_text, param.past_kvcache_text.get_compressed_len() + ) + past_kv_cache_img = self._truncate_kv_cache_to_compressed_len( + past_kv_cache_img, param.past_kvcache_img.get_compressed_len() + ) self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img) image = self.gen_pipe.generate( seed=param.seed, @@ -150,6 +164,11 @@ async def loop_for_netio_req(self): def clean_up(self): pass + def _truncate_kv_cache_to_compressed_len(self, kv: torch.Tensor, compressed_len: int) -> torch.Tensor: + seq = kv.shape[2] + n = min(compressed_len, seq) + return kv[:, :, :n:, :].contiguous() + def setup_devices(args: StartArgs): devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip() diff --git a/lightllm/server/x2i_server/past_kv_cache_client.py b/lightllm/server/x2i_server/past_kv_cache_client.py index 56781f49ab..9e9f0d99d4 100644 --- a/lightllm/server/x2i_server/past_kv_cache_client.py +++ b/lightllm/server/x2i_server/past_kv_cache_client.py @@ -83,7 +83,7 @@ def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optio and token_num > (len(page_indexes) - 1) * self.token_page_size ) (P, L, S, H, D) = self.cpu_kv_cache_tensor[page_indexes].shape - # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (L, 2, H // 2, P, S, D) -> (L, 2, H // 2, P * S, D) + # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (L, 2,P, S, H // 2, D) -> (L, 2, P * S, H // 2, D) kv = ( self.cpu_kv_cache_tensor[page_indexes] .view(P, L, S, 2, H // 2, D) From 0619f9dcab86fb3daf65f67887b4bcf61249ae9d Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 8 Apr 2026 07:09:02 +0000 Subject: [PATCH 10/41] add naive option. --- lightllm/server/api_cli.py | 5 + lightllm/server/core/objs/start_args_type.py | 1 + lightllm/server/core/objs/x2i_params.py | 7 +- lightllm/server/x2i_server/manager.py | 25 +++-- .../x2i_server/naive/modeling_neo_chat.py | 101 +++++++++++++----- .../server/x2i_server/past_kv_cache_client.py | 27 +++-- 6 files changed, 121 insertions(+), 45 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 96697fae22..5cf656d440 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -313,6 +313,11 @@ def make_argument_parser() -> argparse.ArgumentParser: default=1, help="Number of GPUs to use for x2i server (requird --enable_multimodal_x2i).", ) + parser.add_argument( + "--x2i_use_naive_impl", + action="store_true", + help="Whether to use the native backend for x2i generation. 
If set, it will use the naive pytorch backend mainly for testing and debugging purpose.", + ) parser.add_argument( "--x2v_gen_model_config", type=str, diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 5cf13ef19a..7429b22593 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -163,6 +163,7 @@ class StartArgs: x2i_port: int = field(default=None) http_server_port_for_x2i: int = field(default=None) x2i_server_used_gpus: int = field(default=1) + x2i_use_naive_impl: bool = field(default=False) # multi_modal enable_multimodal_x2i: bool = field(default=False) diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py index 9a29ef949d..3e4498ab4b 100644 --- a/lightllm/server/core/objs/x2i_params.py +++ b/lightllm/server/core/objs/x2i_params.py @@ -37,6 +37,7 @@ class X2IParams(ctypes.Structure): ("seed", ctypes.c_int), ("num_images", ctypes.c_int), ("cfg_norm", ctypes.c_int), + ("timestep_shift", ctypes.c_float), ("past_kvcache", PastKVCachePageList), ("past_kvcache_text", PastKVCachePageList), ("past_kvcache_img", PastKVCachePageList), @@ -47,11 +48,12 @@ class X2IParams(ctypes.Structure): _width: int = 1024 _height: int = 1024 _steps: int = 50 - _guidance_scale: float = 7.0 - _image_guidance_scale: float = 7.0 + _guidance_scale: float = 4.0 + _image_guidance_scale: float = 1.0 _seed: int = 42 _num_images: int = 1 _cfg_norm: CfgNormType = CfgNormType.NONE + _timestep_shift: float = 3.0 def init(self, **kwargs): def _get(key, default): @@ -66,6 +68,7 @@ def _get(key, default): self.seed = _get("seed", X2IParams._seed) self.num_images = _get("num_images", X2IParams._num_images) self.cfg_norm = _get("cfg_norm", X2IParams._cfg_norm) + self.timestep_shift = _get("timestep_shift", X2IParams._timestep_shift) self.past_kvcache = PastKVCachePageList() self.past_kvcache_text = PastKVCachePageList() self.past_kvcache_img = PastKVCachePageList() diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index e9e5e6de4b..7dac4f774a 100644 --- a/lightllm/server/x2i_server/manager.py +++ b/lightllm/server/x2i_server/manager.py @@ -50,7 +50,14 @@ def __init__( self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True) + self.use_naive_x2i = args.x2i_use_naive_impl + async def wait_to_model_ready(self): + if self.use_naive_x2i: + from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I + self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device()) + return + from lightx2v import LightX2VPipeline self.gen_pipe = LightX2VPipeline( @@ -62,11 +69,12 @@ async def wait_to_model_ready(self): config_json=self.args.x2v_gen_model_config, ) self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False}) - # from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I - # self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device()) - pass async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams): + if self.use_naive_x2i: + images = self.naive_x2i.t2i(past_kv_cache, past_kv_cache_text, param) + return images + past_kv_cache = self._truncate_kv_cache_to_compressed_len( past_kv_cache, param.past_kvcache.get_compressed_len() ) @@ -84,6 +92,10 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams return [image] async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, 
param: X2IParams): + if self.use_naive_x2i: + images = self.naive_x2i.it2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img, param) + return images + past_kv_cache = self._truncate_kv_cache_to_compressed_len( past_kv_cache, param.past_kvcache.get_compressed_len() ) @@ -100,7 +112,6 @@ async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_i save_result_path="", # 返回base64,不需要指定路径了 target_shape=[param.height, param.width], # Height, Width ) - # images = self.naive_x2i.it2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img, param) return [image] async def loop_for_fwd(self): @@ -113,18 +124,18 @@ async def loop_for_fwd(self): x2i_param = self.waiting_reqs.pop(0) past_kv_cache = self.past_kv_cache_client.get_kv_cache_for_x2i( - x2i_param.past_kvcache.get_all(), x2i_param.past_kvcache.token_len + x2i_param.past_kvcache.get_all(), x2i_param.past_kvcache.token_len, self.use_naive_x2i ) past_kv_cache_text = self.past_kv_cache_client.get_kv_cache_for_x2i( - x2i_param.past_kvcache_text.get_all(), x2i_param.past_kvcache_text.token_len + x2i_param.past_kvcache_text.get_all(), x2i_param.past_kvcache_text.token_len, self.use_naive_x2i ) is_t2i = x2i_param.past_kvcache_img.is_empty() past_kv_cache_img = None if not is_t2i: # t2i past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i( - x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len + x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len, self.use_naive_x2i ) # release diff --git a/lightllm/server/x2i_server/naive/modeling_neo_chat.py b/lightllm/server/x2i_server/naive/modeling_neo_chat.py index 890d761391..326e2a574d 100644 --- a/lightllm/server/x2i_server/naive/modeling_neo_chat.py +++ b/lightllm/server/x2i_server/naive/modeling_neo_chat.py @@ -372,7 +372,7 @@ def it2i_generate(self, img_cfg_scale=1, cfg_norm='none', enable_timestep_shift=True, - timestep_shift=1, + timestep_shift=3, image_size=(256, 256), num_steps=30, cfg_interval=(0.1, 1.0), @@ -421,7 +421,7 @@ def it2i_generate(self, # init noise image tokens grid_h = image_size[1] // self.patch_size grid_w = image_size[0] // self.patch_size - grid_hw = torch.tensor([[grid_h, grid_w]]*batch_size, device=device) + grid_hw = torch.tensor([[grid_h, grid_w]] * batch_size, device=device) noise_scale = self.noise_scale if self.noise_scale_mode in ("resolution", "dynamic", 'dynamic_sqrt'): @@ -446,6 +446,7 @@ def it2i_generate(self, for step_i in range(num_steps): t = timesteps[step_i] t_next = timesteps[step_i + 1] + use_cfg = t >= cfg_interval[0] and t <= cfg_interval[1] z = self.patchify(image_prediction, self.patch_size * merge_size) image_input = self.patchify(image_prediction, self.patch_size, channel_first=True) @@ -459,27 +460,69 @@ def it2i_generate(self, image_embeds = image_embeds + timestep_embeddings v_pred_condition = self._t2i_predict_v(image_embeds, indexes_image_condition, attention_mask_condition, past_key_values_condition, t, z, image_token_num=token_h*token_w, timestep_embeddings=timestep_embeddings,image_size=image_size) - if t > cfg_interval[0] and t < cfg_interval[1]: - if cfg_scale > 1: - v_pred_text_uncondition = self._t2i_predict_v(image_embeds, indexes_image_text_uncondition, attention_mask_text_uncondition, past_key_values_text_uncondition, t, z, image_token_num=token_h*token_w, timestep_embeddings=timestep_embeddings,image_size=image_size) - else: - v_pred_text_uncondition = 0 - if img_cfg_scale > 1: - v_pred_img_uncondition = self._t2i_predict_v(image_embeds, 
indexes_image_img_uncondition, attention_mask_img_uncondition, past_key_values_img_uncondition, t, z, image_token_num=token_h*token_w, timestep_embeddings=timestep_embeddings,image_size=image_size) - else: - v_pred_img_uncondition = 0 + if not use_cfg: + v_pred = v_pred_condition + elif cfg_scale == 1 and img_cfg_scale == 1: + v_pred = v_pred_condition + elif img_cfg_scale == 1: + out_img_cond = self._t2i_predict_v( + image_embeds, + indexes_image_text_uncondition, + attention_mask_text_uncondition, + past_key_values_text_uncondition, + t, + z, + image_token_num=token_h * token_w, + timestep_embeddings=timestep_embeddings, + image_size=image_size, + ) + v_pred = out_img_cond + cfg_scale * (v_pred_condition - out_img_cond) + elif cfg_scale == img_cfg_scale: + out_uncond = self._t2i_predict_v( + image_embeds, + indexes_image_img_uncondition, + attention_mask_img_uncondition, + past_key_values_img_uncondition, + t, + z, + image_token_num=token_h * token_w, + timestep_embeddings=timestep_embeddings, + image_size=image_size, + ) + v_pred = out_uncond + cfg_scale *(v_pred_condition - out_uncond) + else: + out_img_cond = self._t2i_predict_v( + image_embeds, + indexes_image_text_uncondition, + attention_mask_text_uncondition, + past_key_values_text_uncondition, + t, + z, + image_token_num=token_h * token_w, + timestep_embeddings=timestep_embeddings, + image_size=image_size, + ) + out_uncond = self._t2i_predict_v( + image_embeds, + indexes_image_img_uncondition, + attention_mask_img_uncondition, + past_key_values_img_uncondition, + t, + z, + image_token_num=token_h * token_w, + timestep_embeddings=timestep_embeddings, + image_size=image_size, + ) + v_pred = ( + out_uncond + + cfg_scale * (v_pred_condition - out_img_cond) + + img_cfg_scale * (out_img_cond - out_uncond) + ) - if t > cfg_interval[0] and t < cfg_interval[1]: - v_pred_text = v_pred_text_uncondition + cfg_scale * (v_pred_condition - v_pred_text_uncondition) - if cfg_norm == 'text_channel': - norm_v_condition = torch.norm(v_pred_condition, dim=-1, keepdim=True) - norm_v_cfg = torch.norm(v_pred_text, dim=-1, keepdim=True) - scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) - v_pred_text = v_pred_text * scale - v_pred = v_pred_img_uncondition + img_cfg_scale * (v_pred_text - v_pred_img_uncondition) + if cfg_scale > 1 or img_cfg_scale > 1: if cfg_norm == 'global': - norm_v_condition = torch.norm(v_pred_condition, dim=(1,2), keepdim=True) - norm_v_cfg = torch.norm(v_pred, dim=(1,2), keepdim=True) + norm_v_condition = torch.norm(v_pred_condition, dim=(1, 2), keepdim=True) + norm_v_cfg = torch.norm(v_pred, dim=(1, 2), keepdim=True) scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) v_pred = v_pred * scale elif cfg_norm == 'channel': @@ -488,8 +531,6 @@ def it2i_generate(self, scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) v_pred = v_pred * scale - else: - v_pred = v_pred_condition z = z + (t_next - t) * v_pred @@ -517,7 +558,7 @@ def t2i_generate(self, past_key_values_uncondition, text_lens, cfg_scale=1, - timestep_shift=1, + timestep_shift=3, enable_timestep_shift=True, cfg_norm='none', image_size=(256, 256), @@ -601,7 +642,7 @@ def t2i_generate(self, timestep_embeddings=timestep_embeddings, image_size=image_size) - if t > cfg_interval[0] and t < cfg_interval[1] and cfg_scale > 1: + if t >= cfg_interval[0] and t <= cfg_interval[1] and cfg_scale > 1: v_pred_uncondition = self._t2i_predict_v(image_embeds, indexes_image_uncondition, attention_mask_uncondition, 
past_key_values_uncondition, t, z, image_token_num=token_h*token_w, timestep_embeddings=timestep_embeddings, image_size=image_size) if cfg_norm == 'cfg_zero_star': @@ -623,6 +664,12 @@ def t2i_generate(self, norm_v_cfg = torch.norm(v_pred, dim=(1,2), keepdim=True) scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) v_pred = v_pred * scale + elif cfg_norm == 'channel': + norm_v_condition = torch.norm(v_pred_condition, dim=-1, keepdim=True) + norm_v_cfg = torch.norm(v_pred, dim=-1, keepdim=True) + scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) + v_pred = v_pred * scale + else: v_pred = v_pred_condition @@ -718,7 +765,8 @@ def t2i(self, past_kv, past_kv_txt, param: X2IParams): cfg_scale=param.guidance_scale, image_size=(param.width, param.height), num_steps=param.steps, - batch_size=param.num_images) + batch_size=param.num_images, + timestep_shift=param.timestep_shift) return self._post_process(output) @@ -750,6 +798,7 @@ def it2i(self, past_kv, past_kv_txt, past_kv_img, param: X2IParams): image_size=(param.width, param.height), num_steps=param.steps, batch_size=param.num_images, + timestep_shift=param.timestep_shift, ) return self._post_process(output) diff --git a/lightllm/server/x2i_server/past_kv_cache_client.py b/lightllm/server/x2i_server/past_kv_cache_client.py index 9e9f0d99d4..707a3643c7 100644 --- a/lightllm/server/x2i_server/past_kv_cache_client.py +++ b/lightllm/server/x2i_server/past_kv_cache_client.py @@ -75,7 +75,7 @@ def get_pages_by_req_id(self, req_id: int) -> Optional[List[int]]: item = self.allocated_pages_dict.get(req_id, None) return item - def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optional[torch.Tensor]: + def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int, use_naive=False) -> Optional[torch.Tensor]: if page_indexes is None: return None assert ( @@ -83,15 +83,22 @@ def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optio and token_num > (len(page_indexes) - 1) * self.token_page_size ) (P, L, S, H, D) = self.cpu_kv_cache_tensor[page_indexes].shape - # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (L, 2,P, S, H // 2, D) -> (L, 2, P * S, H // 2, D) - kv = ( - self.cpu_kv_cache_tensor[page_indexes] - .view(P, L, S, 2, H // 2, D) - .permute(1, 3, 0, 2, 4, 5) - .contiguous() - .view(L, 2, P * S, H // 2, D) - ) - return kv + if not use_naive: + # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (L, 2,P, S, H // 2, D) -> (L, 2, P * S, H // 2, D) + kv = ( + self.cpu_kv_cache_tensor[page_indexes] + .view(P, L, S, 2, H // 2, D) + .permute(1, 3, 0, 2, 4, 5) + .contiguous() + .view(L, 2, P * S, H // 2, D) + ) + return kv + else: + kv = self.cpu_kv_cache_tensor[page_indexes] \ + .view(P, L, S, 2, H // 2, D) \ + .permute(1, 3, 4, 0, 2, 5).contiguous() \ + .view(L, 2, H // 2, P * S, D) + return kv[:, :, :, :token_num, :].contiguous() def _create_shm_cpu_kv_cache(self): shm_ptr = create_shm_kv_cache_ptr( From e958f11d6e95b83072441a44f2b3a3f0640247f8 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 8 Apr 2026 08:01:00 +0000 Subject: [PATCH 11/41] x2v support --- lightllm/models/neo_chat_moe/model.py | 11 +--- lightllm/server/api_cli.py | 2 +- lightllm/server/api_openai.py | 23 ++++---- lightllm/server/core/objs/start_args_type.py | 6 ++- lightllm/server/core/objs/x2i_params.py | 12 +++-- lightllm/server/httpserver/manager.py | 52 +++++++++++-------- lightllm/server/x2i_server/manager.py | 35 ++++--------- .../server/x2i_server/past_kv_cache_client.py 
| 4 +- 8 files changed, 69 insertions(+), 76 deletions(-) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py index 900d2eb028..648d696877 100644 --- a/lightllm/models/neo_chat_moe/model.py +++ b/lightllm/models/neo_chat_moe/model.py @@ -152,20 +152,13 @@ def get_query_for_it2i(self, prompt: str): query_condition = prompt + IMG_START_TOKEN if not prompt.endswith(IMG_START_TOKEN) else prompt query_text_uncondition = self._build_t2i_query(IMG_TOKEN * image_len) question_img_uncondition = self._build_t2i_query("") - print(f"query_condition: {query_condition}") - print(f"query_text_uncondition: {query_text_uncondition}") - print(f"question_img_uncondition: {question_img_uncondition}") return query_condition, query_text_uncondition, question_img_uncondition def get_query_for_t2i(self, prompt: str): # prompt is already applied + image_len = prompt.count(IMG_TOKEN) query_condition = prompt + IMG_START_TOKEN if not prompt.endswith(IMG_START_TOKEN) else prompt - # query_condition = self._build_t2i_query( - # f"Please generate an image based on the following description: {prompt}", - # thinking_content="\n\n\n\n") - query_uncondition = self._build_t2i_query("") - print(f"query_condition: {query_condition}", flush=True) - print(f"query_uncondition: {query_uncondition}", flush=True) + query_uncondition = self._build_t2i_query("", thinking_content=IMG_TOKEN * image_len) return query_condition, query_uncondition diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 96697fae22..78a3355872 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -637,7 +637,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--cpu_cache_storage_size", type=float, - default=2, + default=50, help="""The capacity of cpu cache. 
GB used.""", ) parser.add_argument( diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 0f364cdc1e..94de694db9 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -538,17 +538,14 @@ async def stream_results() -> AsyncGenerator[bytes, None]: return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) -async def _get_text_generator_input(request: ChatCompletionRequest, apply_chat_template: bool = True): +async def _get_text_generator_input(request: ChatCompletionRequest): from .api_http import g_objs images, audios = _get_images_and_audios(request) multimodal_params_dict = {"images": images, "audios": audios} tools = _get_tools(request) - if apply_chat_template: - prompt = await build_prompt(request, tools) - else: - prompt = request.messages[-1].content + prompt = await build_prompt(request, tools) sampling_params_dict = { "do_sample": request.do_sample, @@ -626,15 +623,13 @@ def _apply_image_generation_stop(chat_request: ChatCompletionRequest, image_star async def _chat_completion_image_only( request: ChatCompletionRequestV2, raw_request: Request, + prompt: str, + multimodal_params: MultimodalParams, + x2i_params: X2IParams, ) -> ChatCompletionResponse: from .api_http import g_objs created_time = int(time.time()) - x2i_params = X2IParams() - x2i_params.init_from_image_config(request.image_config) - - prompt, _, multimodal_params = await _get_text_generator_input(request, apply_chat_template=False) - images = await g_objs.httpserver_manager.generate_image( prompt, x2i_params, multimodal_params.clone(), request=raw_request ) @@ -680,9 +675,6 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request if "image" not in request.modalities: return await chat_completions_impl(chat_request, raw_request) - if request.modalities == ["image"]: - return await _chat_completion_image_only(request, raw_request) - if request.n != 1: return create_error_response( HTTPStatus.BAD_REQUEST, @@ -701,6 +693,9 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request x2i_params = X2IParams() x2i_params.init_from_image_config(request.image_config) + if request.modalities == ["image"]: + return await _chat_completion_image_only(request, raw_request, prompt, multimodal_params, x2i_params) + if not request.stream: from .req_id_generator import convert_sub_id_to_group_id @@ -710,7 +705,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request completion_tokens = 0 finish_reason: Optional[str] = "stop" group_request_id = None - max_image_gen_num = 15 # TODO: make this configurable + max_image_gen_num = 2 # TODO: make this configurable while max_image_gen_num > 0: max_image_gen_num -= 1 diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 5cf13ef19a..6a8d2d2eb3 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -125,7 +125,9 @@ class StartArgs: vit_att_backend: List[str] = field( default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]} ) - llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"]}) + llm_kv_type: str = field( + default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"]} + ) llm_kv_quant_group_size: int = field(default=8) sampling_backend: str = field(default="triton", metadata={"choices": 
["triton", "sglang_kernel"]}) penalty_counter_mode: str = field( @@ -144,7 +146,7 @@ class StartArgs: nixl_pd_kv_page_size: int = field(default=1024) pd_node_id: int = field(default=-1) enable_cpu_cache: bool = field(default=False) - cpu_cache_storage_size: float = field(default=2) + cpu_cache_storage_size: float = field(default=20) cpu_cache_token_page_size: int = field(default=64) enable_disk_cache: bool = field(default=False) disk_cache_storage_size: float = field(default=10) diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py index 9a29ef949d..7704217a39 100644 --- a/lightllm/server/core/objs/x2i_params.py +++ b/lightllm/server/core/objs/x2i_params.py @@ -36,6 +36,8 @@ class X2IParams(ctypes.Structure): ("image_guidance_scale", ctypes.c_float), ("seed", ctypes.c_int), ("num_images", ctypes.c_int), + ("cfg_interval", ctypes.c_float * 2), + ("timestep_shift", ctypes.c_float), ("cfg_norm", ctypes.c_int), ("past_kvcache", PastKVCachePageList), ("past_kvcache_text", PastKVCachePageList), @@ -47,11 +49,13 @@ class X2IParams(ctypes.Structure): _width: int = 1024 _height: int = 1024 _steps: int = 50 - _guidance_scale: float = 7.0 - _image_guidance_scale: float = 7.0 + _guidance_scale: float = 4.0 + _image_guidance_scale: float = 1.0 _seed: int = 42 _num_images: int = 1 - _cfg_norm: CfgNormType = CfgNormType.NONE + _cfg_norm: CfgNormType = CfgNormType.GLOBAL + _cfg_interval: float = (-1, 2) + _timestep_shift: float = 3.0 def init(self, **kwargs): def _get(key, default): @@ -66,6 +70,8 @@ def _get(key, default): self.seed = _get("seed", X2IParams._seed) self.num_images = _get("num_images", X2IParams._num_images) self.cfg_norm = _get("cfg_norm", X2IParams._cfg_norm) + self.cfg_interval = _get("cfg_interval", X2IParams._cfg_interval) + self.timestep_shift = _get("timestep_shift", X2IParams._timestep_shift) self.past_kvcache = PastKVCachePageList() self.past_kvcache_text = PastKVCachePageList() self.past_kvcache_img = PastKVCachePageList() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index aa590aacea..02ea465e32 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -99,6 +99,7 @@ def __init__( if args.enable_multimodal_x2i: from lightllm.server.x2i_server.past_kv_cache_client import PastKVCacheClient + self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=True, init_shm_data=False) self.send_to_x2i = context.socket(zmq.PUSH) self.send_to_x2i.connect(f"{args.zmq_mode}127.0.0.1:{args.x2i_port}") @@ -298,12 +299,10 @@ async def generate( start_time = time.time() request_headers = request.headers if request is not None else {} group_request_id = self.alloc_req_id(sampling_params, is_health_req) - try: original_multimodal_params = None if self.is_multinode_tp_master: original_multimodal_params = copy.deepcopy(multimodal_params) - if self.pd_mode.is_P_or_NORMAL(): await multimodal_params.verify_and_preload(request) @@ -342,7 +341,6 @@ async def generate( # 如果 decode 节点的 ready_kv_len 和 prefill encode 的 len(prompt ids) -1 相等,说明不需要进行 prefill # 直接 raise NixlPrefillNodeStopGenToken raise NixlPrefillNodeStopGenToken(group_request_id=group_request_id) - # 申请资源并存储 alloced_req_indexes = [] while len(alloced_req_indexes) < sampling_params.n: @@ -372,10 +370,10 @@ async def generate( img_tokens = sum([img.token_num for img in multimodal_params.images]) img_len = len(multimodal_params.images) kv_pages = self.past_kv_cache_client.allocate_pages( - req_obj.request_id, 
req_obj.input_len, img_tokens, img_len) + req_obj.request_id, req_obj.input_len, img_tokens, img_len + ) req_obj.past_kv_cache_page_indexes.fill(kv_pages) - req_objs.append(req_obj) logger.debug( @@ -436,13 +434,13 @@ async def generate( raise e return - - async def generate_image(self, prompt: str, generation_params: X2IParams, multimodal_params: MultimodalParams, request: Request): + async def generate_image( + self, prompt: str, generation_params: X2IParams, multimodal_params: MultimodalParams, request: Request + ): generate_req_ids = [] + async def generation_wrapper(prompt, sample, multimodal, request): - async for sub_req_id, _, metadata, finish_status in self.generate( - prompt, sample, multimodal, request - ): + async for sub_req_id, _, metadata, finish_status in self.generate(prompt, sample, multimodal, request): kv_cache_item: PastKVCacheItem = self.past_kv_cache_client.get_pages_by_req_id(sub_req_id) if kv_cache_item is None: raise Exception(f"kv_cache_pages is None for sub_req_id {sub_req_id}") @@ -457,31 +455,41 @@ async def generation_wrapper(prompt, sample, multimodal, request): sample_params = SamplingParams() sample_params.init(self.tokenizer, **{"img_gen_prefill": True}) img_len = len(multimodal_params.images) + image_guidance_scale = generation_params.image_guidance_scale - if img_len > 0: + if img_len > 0 and image_guidance_scale != 1.0: # call it2i # fix prompt, add tag if img_len greater than s in prompt + # this branch is not used now, because image_guidance_scale is always recommended to be 1.0 prompt = self.tokenizer.fix_prompt(prompt, img_len) - prompt_condition, prompt_text_uncondition, prompt_img_uncondition = self.tokenizer.get_query_for_it2i(prompt) - (con_gen, text_uncon_gen, img_uncon_gen) = await asyncio.gather(*[ - generation_wrapper(prompt_condition, sample_params, multimodal_params, request), - generation_wrapper(prompt_text_uncondition, sample_params, multimodal_params.clone(), request), - generation_wrapper(prompt_img_uncondition, sample_params, MultimodalParams(), request)]) + prompt_condition, prompt_text_uncondition, prompt_img_uncondition = self.tokenizer.get_query_for_it2i( + prompt + ) + (con_gen, text_uncon_gen, img_uncon_gen) = await asyncio.gather( + *[ + generation_wrapper(prompt_condition, sample_params, multimodal_params, request), + generation_wrapper(prompt_text_uncondition, sample_params, multimodal_params.clone(), request), + generation_wrapper(prompt_img_uncondition, sample_params, MultimodalParams(), request), + ] + ) generation_params.update_it2i(con_gen, text_uncon_gen, img_uncon_gen) else: # call t2i prompt_condition, prompt_uncondition = self.tokenizer.get_query_for_t2i(prompt) logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}") - (con_gen, uncon_gen) = await asyncio.gather(*[ - generation_wrapper(prompt_condition, sample_params, multimodal_params, request), - generation_wrapper(prompt_uncondition, sample_params, multimodal_params, request)]) + (con_gen, uncon_gen) = await asyncio.gather( + *[ + generation_wrapper(prompt_condition, sample_params, multimodal_params, request), + generation_wrapper(prompt_uncondition, sample_params, multimodal_params.clone(), request), + ] + ) generation_params.update_t2i(con_gen, uncon_gen) # use the first request id as the gen image request id x2i_req_id = generate_req_ids[0] generation_params.request_id = x2i_req_id - req_status = X2IReqStatus(generation_params, generate_req_ids) + req_status = X2IReqStatus(generation_params, generate_req_ids) 
self.req_id_to_x2i_reqs[generation_params.request_id] = req_status # send generation_params to generation server for image generation @@ -859,7 +867,6 @@ async def loop_for_x2i(self): except Exception as e: logger.error(e, exc_info=e) - async def handle_loop(self): self.recycle_event = asyncio.Event() asyncio.create_task(self.recycle_resource_loop()) @@ -964,9 +971,10 @@ def can_release(self): return False return True + class X2IReqStatus: def __init__(self, req_param: X2IParams, req_ids: List[int]): self.req = X2IParams self.req_ids = req_ids self.event = asyncio.Event() - self.response: X2IResponse = None \ No newline at end of file + self.response: X2IResponse = None diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index e9e5e6de4b..a35ba51b09 100644 --- a/lightllm/server/x2i_server/manager.py +++ b/lightllm/server/x2i_server/manager.py @@ -15,7 +15,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name -from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease +from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease, CfgNormType from lightllm.utils.dist_utils import set_current_device_id from .past_kv_cache_client import PastKVCacheClient @@ -61,22 +61,24 @@ async def wait_to_model_ready(self): self.gen_pipe.create_generator( config_json=self.args.x2v_gen_model_config, ) - self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False}) + self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": False}) + # from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I # self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device()) pass async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams): - past_kv_cache = self._truncate_kv_cache_to_compressed_len( - past_kv_cache, param.past_kvcache.get_compressed_len() - ) - past_kv_cache_text = self._truncate_kv_cache_to_compressed_len( - past_kv_cache_text, param.past_kvcache_text.get_compressed_len() + self.gen_pipe.runner.set_inference_params( + index_offset_cond=param.past_kvcache.get_compressed_len(), + index_offset_uncond=param.past_kvcache_text.get_compressed_len(), + cfg_interval=param.cfg_interval, + cfg_scale=param.guidance_scale, + cfg_norm=CfgNormType(param.cfg_norm).as_str(), + timestep_shift=param.timestep_shift, ) - self.gen_pipe.runner.set_kvcache_t2i(past_kv_cache, past_kv_cache_text) + self.gen_pipe.runner.set_kvcache(past_kv_cache, past_kv_cache_text) image = self.gen_pipe.generate( seed=param.seed, - task="t2i", save_result_path="", # 返回base64,不需要指定路径了 target_shape=[param.height, param.width], # Height, Width ) @@ -84,19 +86,9 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams return [image] async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams): - past_kv_cache = self._truncate_kv_cache_to_compressed_len( - past_kv_cache, param.past_kvcache.get_compressed_len() - ) - past_kv_cache_text = self._truncate_kv_cache_to_compressed_len( - past_kv_cache_text, param.past_kvcache_text.get_compressed_len() - ) - past_kv_cache_img = self._truncate_kv_cache_to_compressed_len( - past_kv_cache_img, param.past_kvcache_img.get_compressed_len() - ) self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, 
past_kv_cache_img) image = self.gen_pipe.generate( seed=param.seed, - task="i2i", save_result_path="", # 返回base64,不需要指定路径了 target_shape=[param.height, param.width], # Height, Width ) @@ -164,11 +156,6 @@ async def loop_for_netio_req(self): def clean_up(self): pass - def _truncate_kv_cache_to_compressed_len(self, kv: torch.Tensor, compressed_len: int) -> torch.Tensor: - seq = kv.shape[2] - n = min(compressed_len, seq) - return kv[:, :, :n:, :].contiguous() - def setup_devices(args: StartArgs): devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip() diff --git a/lightllm/server/x2i_server/past_kv_cache_client.py b/lightllm/server/x2i_server/past_kv_cache_client.py index 9e9f0d99d4..6d08e34a2c 100644 --- a/lightllm/server/x2i_server/past_kv_cache_client.py +++ b/lightllm/server/x2i_server/past_kv_cache_client.py @@ -33,6 +33,7 @@ def __init__(self, only_create_meta_data: bool, init_shm_data: bool): self.free_pages: List[int] = list(range(self.page_num)) self.lock = Lock() self.cond = Condition(self.lock) + print("PastKVCacheClient init, page num: ", self.page_num, flush=True) if not only_create_meta_data: if init_shm_data: @@ -86,12 +87,13 @@ def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optio # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (L, 2,P, S, H // 2, D) -> (L, 2, P * S, H // 2, D) kv = ( self.cpu_kv_cache_tensor[page_indexes] + .contiguous() .view(P, L, S, 2, H // 2, D) .permute(1, 3, 0, 2, 4, 5) .contiguous() .view(L, 2, P * S, H // 2, D) ) - return kv + return kv[:, :, :token_num, :, :].contiguous() def _create_shm_cpu_kv_cache(self): shm_ptr = create_shm_kv_cache_ptr( From 539576b82dba7f0e7cb8ab9e924fcfedf0b61bc9 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 8 Apr 2026 12:49:19 +0000 Subject: [PATCH 12/41] fix x2v acc, because of seed --- lightllm/server/x2i_server/manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index 5865d47746..b4921b69ee 100644 --- a/lightllm/server/x2i_server/manager.py +++ b/lightllm/server/x2i_server/manager.py @@ -90,7 +90,7 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams ) self.gen_pipe.runner.set_kvcache(past_kv_cache, past_kv_cache_text) image = self.gen_pipe.generate( - seed=param.seed, + seed=param.seed + param.past_kvcache.img_len, save_result_path="", # 返回base64,不需要指定路径了 target_shape=[param.height, param.width], # Height, Width ) @@ -112,7 +112,7 @@ async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_i ) self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img) image = self.gen_pipe.generate( - seed=param.seed, + seed=param.seed + param.past_kvcache_img.img_len, save_result_path="", # 返回base64,不需要指定路径了 target_shape=[param.height, param.width], # Height, Width ) From 29dda22014c1ed48fd66b9dad92e4c1cc50c6c17 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 8 Apr 2026 13:36:20 +0000 Subject: [PATCH 13/41] fix chat template --- lightllm/server/build_prompt.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index 7f16d519a2..a33e35e815 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -16,7 +16,10 @@ def init_tokenizer(args): if chat_path is not None: with open(chat_path, "r", encoding="utf-8") as f: chat_template_str = f.read() - 
tokenizer.chat_template = chat_template_str + if hasattr(tokenizer, "tokenizer"): + tokenizer.tokenizer.chat_template = chat_template_str + else: + tokenizer.chat_template = chat_template_str return # 如果 tokenizer 目录下存在chat_template.json, 同时不存在 chat_template.jinja, From e4effb8d5818daee2ba5caf6fd80ab9b0dce3be6 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 8 Apr 2026 14:55:46 +0000 Subject: [PATCH 14/41] smart resize --- lightllm/server/api_models.py | 6 +++++- lightllm/server/api_openai.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 643bba66e3..664a248e80 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -458,7 +458,11 @@ def get_resolution(self): scale = self._size_multiplier[self.image_size] w, h = base - return int(w * scale), int(h * scale) + w, h = int(w * scale), int(h * scale) + from lightllm.models.neo_chat_moe.vision_process import smart_resize + + h, w = smart_resize(h, w, factor=32, min_pixels=512 * 512, max_pixels=2048 * 2048) + return w, h class ChatCompletionRequestV2(ChatCompletionRequest): diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index f446f3e536..e9603f008d 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -636,7 +636,7 @@ async def _chat_completion_image_only( response_images = _message_contents_from_raw_images(images, request.image_config.image_type) chat_message = ChatMessage( role="assistant", - content=prompt, + content="", images=response_images if response_images else None, ) choice = ChatCompletionResponseChoice( From a71817ae83e15af1f137704f452158e60bf4747a Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Thu, 9 Apr 2026 05:50:43 +0000 Subject: [PATCH 15/41] nit. 
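The interesting piece here is in past_kv_cache.py: the early `return` when no
transfer task has finished yet is removed, presumably so that every rank in the
data-parallel group still enters the MIN all-reduce a few lines below; a rank
that returned early would leave its peers blocked inside the collective. A
minimal illustration of that pattern (generic names, assumes an already
initialized process group; not the module's actual code):

    import torch
    import torch.distributed as dist

    def sync_finished_count(finished_tasks, group):
        # Every rank must call the collective, even with zero finished tasks,
        # otherwise ranks that skip it deadlock the ones that do call it.
        ok = torch.tensor(len(finished_tasks))
        dist.all_reduce(ok, op=dist.ReduceOp.MIN, group=group)
        # Only proceed with the number of tasks that *all* ranks agree on.
        return int(ok.item())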
--- lightllm/server/api_openai.py | 2 +- .../router/model_infer/mode_backend/past_kv_cache.py | 3 --- lightllm/server/x2i_server/manager.py | 7 ++++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index e9603f008d..52fe59aa39 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -863,7 +863,7 @@ async def stream_result() -> AsyncGenerator[bytes, None]: prompt, x2i_params, multimodal_params.clone(), request=raw_request ) - if len(images) == 0: + if images is None or len(images) == 0: logger.warning(f"No image generated by x2i: {prompt[-100:]}, exit...") break diff --git a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py index de80da5d01..9035de219b 100644 --- a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py @@ -140,9 +140,6 @@ def update_past_kv_cache_task_states(self): self.past_kv_cache_task.appendleft(task) break - if len(trans_ok_tasks) == 0: - return - ok_tasks_num = torch.tensor(len(trans_ok_tasks)) dist.all_reduce(ok_tasks_num, op=dist.ReduceOp.MIN, group=self.sync_task_status_group) diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index b4921b69ee..d8ed2a4c7d 100644 --- a/lightllm/server/x2i_server/manager.py +++ b/lightllm/server/x2i_server/manager.py @@ -1,11 +1,11 @@ import zmq -import zmq.asyncio import asyncio import uvloop import inspect import setproctitle import pickle import torch +import time import os from typing import List from lightllm.server.core.objs import StartArgs @@ -149,10 +149,12 @@ async def loop_for_fwd(self): images = [] logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with: {x2i_param}") + start_t = time.time() if is_t2i: images = await self.t2i_generate(past_kv_cache, past_kv_cache_text, x2i_param) else: images = await self.it2i_generate(past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param) + logger.info(f"generate {len(images)} images done, cost {time.time() - start_t:.2f}s") self.send_to_httpserver.send_pyobj( X2IResponse(request_id=x2i_param.request_id, images=images), protocol=pickle.HIGHEST_PROTOCOL @@ -162,8 +164,7 @@ async def loop_for_fwd(self): self.send_to_httpserver.send_pyobj( X2IResponse(request_id=x2i_param.request_id, images=None), protocol=pickle.HIGHEST_PROTOCOL ) - - logger.error(e) + logger.error(e, exc_info=e) async def loop_for_netio_req(self): while True: From 044158c89008ddbcb46b8831c028d821f717923c Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Thu, 9 Apr 2026 10:05:08 +0000 Subject: [PATCH 16/41] fix device. --- lightllm/server/api_start.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 5ce80bf419..e034d9f437 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -366,6 +366,8 @@ def normal_or_p_d_start(args): ) if origin_devices: os.environ["CUDA_VISIBLE_DEVICES"] = origin_devices + else: + os.environ.pop("CUDA_VISIBLE_DEVICES", None) if args.enable_cpu_cache: from .multi_level_kv_cache.manager import start_multi_level_kv_cache_manager From 067a6795302ef15bad1162c7146e78f0e26ea255 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Fri, 10 Apr 2026 09:15:08 +0000 Subject: [PATCH 17/41] support distributed lightx2v. 
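This wires the x2i image-generation path onto multiple GPUs: when
`x2i_server_used_gpus > 1` the manager no longer runs the lightx2v pipeline
itself, it broadcasts each X2IParams to one worker process per rank over a
ZeroMQ PUB/SUB socket pair (two extra ports are allocated, one for the task
socket and one for the workers' torch.distributed init), and only rank 0 sends
the generated images back to the http server. A rough sketch of that topology,
with illustrative names rather than the exact ones added here:

    import zmq

    def worker_loop(rank, task_port, result_port, run_pipeline):
        ctx = zmq.Context()
        # every rank subscribes to the manager's broadcast socket
        task_sock = ctx.socket(zmq.SUB)
        task_sock.connect(f"tcp://127.0.0.1:{task_port}")
        task_sock.setsockopt(zmq.SUBSCRIBE, b"")
        result_sock = None
        if rank == 0:
            # only rank 0 reports results back to the http server
            result_sock = ctx.socket(zmq.PUSH)
            result_sock.connect(f"tcp://127.0.0.1:{result_port}")
        while True:
            param = task_sock.recv_pyobj()    # same object arrives on every rank
            images = run_pipeline(param)      # collective work across all ranks
            if rank == 0:
                result_sock.send_pyobj(images)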
--- lightllm/server/api_start.py | 10 +- lightllm/server/core/objs/start_args_type.py | 2 + lightllm/server/httpserver/manager.py | 4 +- .../server/x2i_server/lightx2v/adapter.py | 145 ++++++++++++++++++ lightllm/server/x2i_server/manager.py | 133 ++++++++++------ 5 files changed, 239 insertions(+), 55 deletions(-) create mode 100644 lightllm/server/x2i_server/lightx2v/adapter.py diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index e034d9f437..6413ec22b0 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -250,7 +250,7 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=12 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports + num=14 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -266,8 +266,10 @@ def normal_or_p_d_start(args): pd_decode_rpyc_port, x2i_port, http_server_port_for_x2i, - ) = can_use_ports[0:12] - can_use_ports = can_use_ports[12:] + x2i_worker_nccl_port, + x2i_worker_task_port, + ) = can_use_ports[0:14] + can_use_ports = can_use_ports[14:] visual_model_tp_ports = [] visual_nccl_ports = [] @@ -294,6 +296,8 @@ def normal_or_p_d_start(args): args.visual_nccl_ports = visual_nccl_ports args.x2i_port = x2i_port args.http_server_port_for_x2i = http_server_port_for_x2i + args.x2i_worker_task_port = x2i_worker_task_port + args.x2i_worker_nccl_port = x2i_worker_nccl_port # 申请在 p d 分离模式下,会用的端口 args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] # p d 分离模式下用于标识节点的id diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 17cdeed06c..ece47a2bf1 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -166,6 +166,8 @@ class StartArgs: http_server_port_for_x2i: int = field(default=None) x2i_server_used_gpus: int = field(default=1) x2i_use_naive_impl: bool = field(default=False) + x2i_worker_task_port: int = field(default=None) + x2i_worker_nccl_port: int = field(default=None) # multi_modal enable_multimodal_x2i: bool = field(default=False) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 02ea465e32..6ca8e86bc6 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -104,7 +104,7 @@ def __init__( self.send_to_x2i = context.socket(zmq.PUSH) self.send_to_x2i.connect(f"{args.zmq_mode}127.0.0.1:{args.x2i_port}") self.recv_from_x2i = context.socket(zmq.PULL) - self.recv_from_x2i.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port_for_x2i}") + self.recv_from_x2i.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port_for_x2i}") self.req_id_to_x2i_reqs: Dict[int, X2IReqStatus] = {} self.shm_req_manager = ShmReqManager() @@ -477,7 +477,7 @@ async def generation_wrapper(prompt, sample, multimodal, request): else: # call t2i prompt_condition, prompt_uncondition = self.tokenizer.get_query_for_t2i(prompt) - logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}") + # logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}") (con_gen, uncon_gen) = await asyncio.gather( *[ generation_wrapper(prompt_condition, sample_params, multimodal_params, request), diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py new file mode 100644 index 
0000000000..cd0de51d88 --- /dev/null +++ b/lightllm/server/x2i_server/lightx2v/adapter.py @@ -0,0 +1,145 @@ +import inspect +import torch +import torch.distributed as dist +import zmq +import setproctitle +import asyncio +import os + +from lightllm.server.core.objs import StartArgs +from lightllm.utils.log_utils import init_logger +from lightllm.utils.graceful_utils import graceful_registry +from lightllm.utils.process_check import start_parent_check_thread +from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease, CfgNormType +from ..past_kv_cache_client import PastKVCacheClient + +logger = init_logger(__name__) + + +class LightX2VServer: + def __init__(self, args: StartArgs, rank: int, world_size: int): + self.args = args + self.rank = rank + self.world_size = world_size + + context = zmq.Context(2) + + # receive task from manager + self.task_socket = context.socket(zmq.SUB) + self.task_socket.connect(f"{args.zmq_mode}127.0.0.1:{self.args.x2i_worker_task_port}") + self.task_socket.setsockopt(zmq.SUBSCRIBE, b"") + + if self.rank == 0: + # send result back + self.result_socket = context.socket(zmq.PUSH) + self.result_socket.connect(f"{args.zmq_mode}127.0.0.1:{self.args.http_server_port_for_x2i}") + + self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=False) + torch.cuda.set_device(rank) + self._init_pipeline() + + def _init_pipeline(self): + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(self.args.x2i_worker_nccl_port) + os.environ["RANK"] = str(self.rank) + os.environ["WORLD_SIZE"] = str(self.world_size) + + from lightx2v import LightX2VPipeline + + self.pipe = LightX2VPipeline( + model_path=self.args.model_dir, + model_cls="neopp", + support_tasks=["t2i", "i2i"], + ) + self.pipe.create_generator(config_json=self.args.x2v_gen_model_config) + self.pipe.modify_config({ + "load_kv_cache_in_pipeline_for_debug": False, + "save_result_for_debug": False}) + + async def run(self): + while True: + param: X2IParams = self.task_socket.recv_pyobj() + + try: + images = self._process(param) + + if self.rank == 0: + self.result_socket.send_pyobj(X2IResponse(request_id=param.request_id, images=images)) + + except Exception as e: + logger.error(f"Error processing request {param.request_id}: {str(e)}", exc_info=e) + if self.rank == 0: + self.result_socket.send_pyobj(X2IResponse(request_id=param.request_id, images=None)) + + def _process(self, param: X2IParams): + is_t2i = param.past_kvcache_img.is_empty() + + self.pipe.runner.set_inference_params( + index_offset_cond=param.past_kvcache.get_compressed_len(), + index_offset_uncond=param.past_kvcache_text.get_compressed_len(), + cfg_interval=param.cfg_interval, + cfg_scale=param.guidance_scale, + cfg_norm=CfgNormType(param.cfg_norm).as_str(), + timestep_shift=param.timestep_shift, + ) + past_kv_cache = self.past_kv_cache_client.get_kv_cache_for_x2i( + param.past_kvcache.get_all(), param.past_kvcache.token_len) + past_kv_cache_text = self.past_kv_cache_client.get_kv_cache_for_x2i( + param.past_kvcache_text.get_all(), param.past_kvcache_text.token_len) + past_kv_cache_img = None + if not is_t2i: + past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i( + param.past_kvcache_img.get_all(), param.past_kvcache_img.token_len) + + dist.barrier() # ensure all workers have got the kv cache before generation starts + + if self.rank == 0: + # release + 
self.result_socket.send_pyobj(X2ICacheRelease(request_id=param.request_id)) + + if is_t2i: + self.pipe.runner.set_kvcache( + past_kv_cache, + past_kv_cache_text, + ) + else: + self.pipe.runner.set_kvcache_i2i( + past_kv_cache, + past_kv_cache_text, + past_kv_cache_img, + ) + image = self.pipe.generate( + seed=param.seed + param.past_kvcache.img_len, + save_result_path="", + target_shape=[param.height, param.width], + ) + + return [image] + + +def start_x2v_process(args: StartArgs, rank: int, world_size: int, pipe_writer): + + # 注册graceful 退出的处理 + graceful_registry(inspect.currentframe().f_code.co_name) + setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::x2v_server_{rank}") + start_parent_check_thread() + + try: + x2v_server = LightX2VServer(args=args, rank=rank, world_size=world_size) + except Exception as e: + logger.exception(str(e), exc_info=e) + raise e + + pipe_writer.send("init ok") + + def handle_exception(loop, context): + logger.exception(f"X2VServer Caught exception: {str(context)}") + + loop = asyncio.new_event_loop() + loop.set_exception_handler(handle_exception) + asyncio.set_event_loop(loop) + + loop.run_until_complete(x2v_server.run()) + + return diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index d8ed2a4c7d..8b86e12030 100644 --- a/lightllm/server/x2i_server/manager.py +++ b/lightllm/server/x2i_server/manager.py @@ -6,6 +6,7 @@ import pickle import torch import time +import multiprocessing as mp import os from typing import List from lightllm.server.core.objs import StartArgs @@ -17,6 +18,7 @@ from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease, CfgNormType from lightllm.utils.dist_utils import set_current_device_id +from lightllm.utils.start_utils import start_submodule_processes from .past_kv_cache_client import PastKVCacheClient logger = init_logger(__name__) @@ -29,8 +31,23 @@ 3. call llm gen to obtain past key values 4. call x2v to generate images and pass the key values to it 5. return the generated images. -""" + +-------------------+ + | X2IManager | + +---------+---------+ + | + (broadcast) + | + +------+------+------+ + | | | + Worker0 Worker1 ... 
+ (rank0) (rank1) + | | + +------ allreduce / sync ----+ + | + only rank0 + returns result +""" class X2IManager: def __init__( @@ -40,40 +57,51 @@ def __init__( context = zmq.Context(2) self.args = args + # from http server self.zmq_recv_socket = context.socket(zmq.PULL) self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.x2i_port}") + # to http server self.send_to_httpserver = context.socket(zmq.PUSH) - self.send_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port_for_x2i}") + self.send_to_httpserver.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port_for_x2i}") + + self.use_naive_x2i = args.x2i_use_naive_impl + self.world_size = args.x2i_server_used_gpus + + if not self.use_naive_x2i and self.world_size > 1: + # send to workers + self.worker_pub = context.socket(zmq.PUB) + self.worker_pub.bind(f"{args.zmq_mode}127.0.0.1:{args.x2i_worker_task_port}") self.waiting_reqs: List[X2IParams] = [] self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True) - self.use_naive_x2i = args.x2i_use_naive_impl async def wait_to_model_ready(self): - if self.use_naive_x2i: - from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I - self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device()) - return - - from lightx2v import LightX2VPipeline - - self.gen_pipe = LightX2VPipeline( - model_path=self.args.model_dir, - model_cls="neopp", - support_tasks=["t2i", "i2i"], - ) - self.gen_pipe.create_generator( - config_json=self.args.x2v_gen_model_config, - ) - self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": False}) - - # from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I - # self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device()) - pass + if self.world_size <= 1: + if self.use_naive_x2i: + from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I + self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device()) + else: + from lightx2v import LightX2VPipeline + + self.gen_pipe = LightX2VPipeline( + model_path=self.args.model_dir, + model_cls="neopp", + support_tasks=["t2i", "i2i"], + ) + self.gen_pipe.create_generator( + config_json=self.args.x2v_gen_model_config, + ) + self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": False}) + else: + # distribted x2v + from lightllm.server.x2i_server.lightx2v.adapter import start_x2v_process + funcs = [start_x2v_process] * self.world_size + args = [(self.args, rank, self.world_size) for rank in range(self.world_size)] + start_submodule_processes(funcs, args) async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams): if self.use_naive_x2i: @@ -127,38 +155,42 @@ async def loop_for_fwd(self): x2i_param = self.waiting_reqs.pop(0) - past_kv_cache = self.past_kv_cache_client.get_kv_cache_for_x2i( - x2i_param.past_kvcache.get_all(), x2i_param.past_kvcache.token_len, self.use_naive_x2i - ) - - past_kv_cache_text = self.past_kv_cache_client.get_kv_cache_for_x2i( - x2i_param.past_kvcache_text.get_all(), x2i_param.past_kvcache_text.token_len, self.use_naive_x2i - ) - is_t2i = x2i_param.past_kvcache_img.is_empty() + if not self.use_naive_x2i and self.world_size > 1: + # broadcast to workers + self.worker_pub.send_pyobj(x2i_param, protocol=pickle.HIGHEST_PROTOCOL) + else: + past_kv_cache = self.past_kv_cache_client.get_kv_cache_for_x2i( + x2i_param.past_kvcache.get_all(), x2i_param.past_kvcache.token_len, 
self.use_naive_x2i + ) - past_kv_cache_img = None - if not is_t2i: # t2i - past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i( - x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len, self.use_naive_x2i + past_kv_cache_text = self.past_kv_cache_client.get_kv_cache_for_x2i( + x2i_param.past_kvcache_text.get_all(), x2i_param.past_kvcache_text.token_len, self.use_naive_x2i ) + is_t2i = x2i_param.past_kvcache_img.is_empty() - # release - self.send_to_httpserver.send_pyobj( - X2ICacheRelease(request_id=x2i_param.request_id), protocol=pickle.HIGHEST_PROTOCOL - ) + past_kv_cache_img = None + if not is_t2i: # t2i + past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i( + x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len, self.use_naive_x2i + ) - images = [] - logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with: {x2i_param}") - start_t = time.time() - if is_t2i: - images = await self.t2i_generate(past_kv_cache, past_kv_cache_text, x2i_param) - else: - images = await self.it2i_generate(past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param) - logger.info(f"generate {len(images)} images done, cost {time.time() - start_t:.2f}s") + # release + self.send_to_httpserver.send_pyobj( + X2ICacheRelease(request_id=x2i_param.request_id), protocol=pickle.HIGHEST_PROTOCOL + ) - self.send_to_httpserver.send_pyobj( - X2IResponse(request_id=x2i_param.request_id, images=images), protocol=pickle.HIGHEST_PROTOCOL - ) + images = [] + logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with: {x2i_param}") + start_t = time.time() + if is_t2i: + images = await self.t2i_generate(past_kv_cache, past_kv_cache_text, x2i_param) + else: + images = await self.it2i_generate(past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param) + logger.info(f"generate {len(images)} images done, cost {time.time() - start_t:.2f}s") + + self.send_to_httpserver.send_pyobj( + X2IResponse(request_id=x2i_param.request_id, images=images), protocol=pickle.HIGHEST_PROTOCOL + ) except Exception as e: self.send_to_httpserver.send_pyobj( @@ -190,6 +222,7 @@ def setup_devices(args: StartArgs): devices = [int(x.strip()) for x in devices.split(",") if x.strip()] llm_need_gpus = args.tp * args.dp + # llm_need_gpus = 0 x2i_need_gpus = args.x2i_server_used_gpus if len(devices) < llm_need_gpus + x2i_need_gpus: raise ValueError(f"devices {devices} not enough, need {llm_need_gpus} and {x2i_need_gpus}") From 6cf6a101138410d5d9d565e08d505e4574a45151 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 13 Apr 2026 03:22:44 +0000 Subject: [PATCH 18/41] change task distribute from pub/sub to push/pull --- .../server/x2i_server/lightx2v/adapter.py | 35 ++++++++++++------- lightllm/server/x2i_server/manager.py | 18 +--------- 2 files changed, 24 insertions(+), 29 deletions(-) diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py index cd0de51d88..47472a7c53 100644 --- a/lightllm/server/x2i_server/lightx2v/adapter.py +++ b/lightllm/server/x2i_server/lightx2v/adapter.py @@ -2,6 +2,7 @@ import torch import torch.distributed as dist import zmq +import zmq.asyncio import setproctitle import asyncio import os @@ -23,14 +24,13 @@ def __init__(self, args: StartArgs, rank: int, world_size: int): self.rank = rank self.world_size = world_size - context = zmq.Context(2) - # receive task from manager - self.task_socket = context.socket(zmq.SUB) - 
self.task_socket.connect(f"{args.zmq_mode}127.0.0.1:{self.args.x2i_worker_task_port}") - self.task_socket.setsockopt(zmq.SUBSCRIBE, b"") - if self.rank == 0: + context = zmq.asyncio.Context(2) + self.task_socket = context.socket(zmq.PULL) + self.task_socket.connect(f"{args.zmq_mode}127.0.0.1:{self.args.x2i_worker_task_port}") + # self.task_socket.setsockopt(zmq.SUBSCRIBE, b"") + # send result back self.result_socket = context.socket(zmq.PUSH) self.result_socket.connect(f"{args.zmq_mode}127.0.0.1:{self.args.http_server_port_for_x2i}") @@ -39,6 +39,8 @@ def __init__(self, args: StartArgs, rank: int, world_size: int): torch.cuda.set_device(rank) self._init_pipeline() + self.task_dist_group = dist.new_group(backend="gloo") + def _init_pipeline(self): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(self.args.x2i_worker_nccl_port) @@ -59,20 +61,29 @@ def _init_pipeline(self): async def run(self): while True: - param: X2IParams = self.task_socket.recv_pyobj() try: - images = self._process(param) + if self.rank == 0: + param: X2IParams = await self.task_socket.recv_pyobj() + dist.broadcast_object_list([param], src=0, group=self.task_dist_group) + else: + params = [None] + dist.broadcast_object_list(params, src=0, group=self.task_dist_group) + param: X2IParams = params[0] + + assert param is not None, "Received None param in x2v worker, this should not happen." + + images = await self._process(param) if self.rank == 0: - self.result_socket.send_pyobj(X2IResponse(request_id=param.request_id, images=images)) + await self.result_socket.send_pyobj(X2IResponse(request_id=param.request_id, images=images)) except Exception as e: logger.error(f"Error processing request {param.request_id}: {str(e)}", exc_info=e) if self.rank == 0: - self.result_socket.send_pyobj(X2IResponse(request_id=param.request_id, images=None)) + await self.result_socket.send_pyobj(X2IResponse(request_id=param.request_id, images=None)) - def _process(self, param: X2IParams): + async def _process(self, param: X2IParams): is_t2i = param.past_kvcache_img.is_empty() self.pipe.runner.set_inference_params( @@ -96,7 +107,7 @@ def _process(self, param: X2IParams): if self.rank == 0: # release - self.result_socket.send_pyobj(X2ICacheRelease(request_id=param.request_id)) + await self.result_socket.send_pyobj(X2ICacheRelease(request_id=param.request_id)) if is_t2i: self.pipe.runner.set_kvcache( diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index 8b86e12030..f715df74b5 100644 --- a/lightllm/server/x2i_server/manager.py +++ b/lightllm/server/x2i_server/manager.py @@ -31,22 +31,6 @@ 3. call llm gen to obtain past key values 4. call x2v to generate images and pass the key values to it 5. return the generated images. - - +-------------------+ - | X2IManager | - +---------+---------+ - | - (broadcast) - | - +------+------+------+ - | | | - Worker0 Worker1 ... - (rank0) (rank1) - | | - +------ allreduce / sync ----+ - | - only rank0 - returns result """ class X2IManager: @@ -70,7 +54,7 @@ def __init__( if not self.use_naive_x2i and self.world_size > 1: # send to workers - self.worker_pub = context.socket(zmq.PUB) + self.worker_pub = context.socket(zmq.PUSH) self.worker_pub.bind(f"{args.zmq_mode}127.0.0.1:{args.x2i_worker_task_port}") self.waiting_reqs: List[X2IParams] = [] From 29850a61a9e4132fc4e226b437c07decad7adda3 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 13 Apr 2026 05:43:11 +0000 Subject: [PATCH 19/41] use poller. 
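Likely motivation: `asyncio.wait_for(sock.recv_pyobj(), timeout=...)` cancels
the pending receive future whenever the timeout fires, and with pyzmq's asyncio
sockets a cancelled recv can swallow a message that was already dequeued.
Polling first and only calling recv when data is actually available never
cancels a recv; the 50 ms poll matches the previous 0.05 s wait_for timeout.
The pattern in isolation (illustrative handler, not the manager's code):

    import zmq
    import zmq.asyncio

    async def drain(sock: zmq.asyncio.Socket, handle) -> None:
        poller = zmq.asyncio.Poller()
        poller.register(sock, zmq.POLLIN)
        while True:
            events = dict(await poller.poll(timeout=50))  # milliseconds
            if sock in events:
                obj = await sock.recv_pyobj()  # data is ready, recv never cancelled
                handle(obj)
            # else: nothing pending, loop and do other housekeeping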
--- lightllm/server/httpserver/manager.py | 9 ++++++++- lightllm/server/x2i_server/lightx2v/adapter.py | 1 - 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 6ca8e86bc6..4a4bf937a3 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -842,10 +842,17 @@ async def recycle_resource_loop(self): return async def loop_for_x2i(self): + poller = zmq.asyncio.Poller() + poller.register(self.recv_from_x2i, zmq.POLLIN) while True: try: - recv_obj = await asyncio.wait_for(self.recv_from_x2i.recv_pyobj(), timeout=0.05) + events = dict(await poller.poll(timeout=50)) + + if self.recv_from_x2i in events: + recv_obj = await self.recv_from_x2i.recv_pyobj() + else: + continue if isinstance(recv_obj, X2ICacheRelease): status = self.req_id_to_x2i_reqs[recv_obj.request_id] diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py index 47472a7c53..76900a92f2 100644 --- a/lightllm/server/x2i_server/lightx2v/adapter.py +++ b/lightllm/server/x2i_server/lightx2v/adapter.py @@ -29,7 +29,6 @@ def __init__(self, args: StartArgs, rank: int, world_size: int): context = zmq.asyncio.Context(2) self.task_socket = context.socket(zmq.PULL) self.task_socket.connect(f"{args.zmq_mode}127.0.0.1:{self.args.x2i_worker_task_port}") - # self.task_socket.setsockopt(zmq.SUBSCRIBE, b"") # send result back self.result_socket = context.socket(zmq.PUSH) From 40a16b0e12cb7b1f9ede824dda1dc82912bbd1f2 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 13 Apr 2026 06:38:59 +0000 Subject: [PATCH 20/41] add num_images for x2v --- lightllm/server/x2i_server/manager.py | 51 +++++++++++++++++---------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index f715df74b5..28462da8f1 100644 --- a/lightllm/server/x2i_server/manager.py +++ b/lightllm/server/x2i_server/manager.py @@ -33,6 +33,7 @@ 5. return the generated images. 
""" + class X2IManager: def __init__( self, @@ -61,12 +62,12 @@ def __init__( self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True) - async def wait_to_model_ready(self): if self.world_size <= 1: if self.use_naive_x2i: from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I + self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device()) else: from lightx2v import LightX2VPipeline @@ -79,10 +80,13 @@ async def wait_to_model_ready(self): self.gen_pipe.create_generator( config_json=self.args.x2v_gen_model_config, ) - self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": False}) + self.gen_pipe.modify_config( + {"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": False} + ) else: # distribted x2v from lightllm.server.x2i_server.lightx2v.adapter import start_x2v_process + funcs = [start_x2v_process] * self.world_size args = [(self.args, rank, self.world_size) for rank in range(self.world_size)] start_submodule_processes(funcs, args) @@ -100,14 +104,16 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams cfg_norm=CfgNormType(param.cfg_norm).as_str(), timestep_shift=param.timestep_shift, ) - self.gen_pipe.runner.set_kvcache(past_kv_cache, past_kv_cache_text) - image = self.gen_pipe.generate( - seed=param.seed + param.past_kvcache.img_len, - save_result_path="", # 返回base64,不需要指定路径了 - target_shape=[param.height, param.width], # Height, Width - ) - # images = self.naive_x2i.t2i(past_kv_cache, past_kv_cache_text, param) - return [image] + images = [] + for i in range(param.num_images): + self.gen_pipe.runner.set_kvcache(past_kv_cache, past_kv_cache_text) + image = self.gen_pipe.generate( + seed=param.seed + param.past_kvcache.img_len + i, + save_result_path="", # 返回base64,不需要指定路径了 + target_shape=[param.height, param.width], # Height, Width + ) + images.append(image) + return images async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams): if self.use_naive_x2i: @@ -122,13 +128,16 @@ async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_i cfg_norm=CfgNormType(param.cfg_norm).as_str(), timestep_shift=param.timestep_shift, ) - self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img) - image = self.gen_pipe.generate( - seed=param.seed + param.past_kvcache_img.img_len, - save_result_path="", # 返回base64,不需要指定路径了 - target_shape=[param.height, param.width], # Height, Width - ) - return [image] + images = [] + for i in range(param.num_images): + self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img) + image = self.gen_pipe.generate( + seed=param.seed + param.past_kvcache_img.img_len + i, + save_result_path="", # 返回base64,不需要指定路径了 + target_shape=[param.height, param.width], # Height, Width + ) + images.append(image) + return images async def loop_for_fwd(self): while True: @@ -155,7 +164,9 @@ async def loop_for_fwd(self): past_kv_cache_img = None if not is_t2i: # t2i past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i( - x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len, self.use_naive_x2i + x2i_param.past_kvcache_img.get_all(), + x2i_param.past_kvcache_img.token_len, + self.use_naive_x2i, ) # release @@ -169,7 +180,9 @@ async def loop_for_fwd(self): if is_t2i: images = await self.t2i_generate(past_kv_cache, past_kv_cache_text, x2i_param) else: - images = await 
self.it2i_generate(past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param) + images = await self.it2i_generate( + past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param + ) logger.info(f"generate {len(images)} images done, cost {time.time() - start_t:.2f}s") self.send_to_httpserver.send_pyobj( From af5c710a4aad463a6c8e1f025b33ba6c36925803 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Mon, 13 Apr 2026 11:18:39 +0000 Subject: [PATCH 21/41] fixup. --- lightllm/server/api_cli.py | 7 +++++++ lightllm/server/core/objs/start_args_type.py | 1 + lightllm/server/httpserver/manager.py | 11 ++++++++--- lightllm/server/x2i_server/lightx2v/adapter.py | 3 ++- lightllm/server/x2i_server/manager.py | 3 +-- 5 files changed, 19 insertions(+), 6 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 34b7226470..6efb2a8843 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -313,6 +313,13 @@ def make_argument_parser() -> argparse.ArgumentParser: default=1, help="Number of GPUs to use for x2i server (requird --enable_multimodal_x2i).", ) + parser.add_argument( + "--x2i_server_deploy_mode", + type=str, + choices=["colocate", "separate"], + default="colocate", + help="Deployment mode for the x2i server. 'colocate' means the x2i server will run on the same gpus as the llm server, ", + ) parser.add_argument( "--x2i_use_naive_impl", action="store_true", diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index ece47a2bf1..5599a831ec 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -165,6 +165,7 @@ class StartArgs: x2i_port: int = field(default=None) http_server_port_for_x2i: int = field(default=None) x2i_server_used_gpus: int = field(default=1) + x2i_server_deploy_mode: str = field(default="colocate") x2i_use_naive_impl: bool = field(default=False) x2i_worker_task_port: int = field(default=None) x2i_worker_nccl_port: int = field(default=None) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 4a4bf937a3..0304ecea91 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -454,6 +454,9 @@ async def generation_wrapper(prompt, sample, multimodal, request): # 1. 
construct 3 or 2 generate based on the multimodel_parmas sample_params = SamplingParams() sample_params.init(self.tokenizer, **{"img_gen_prefill": True}) + sample_params1 = SamplingParams() + sample_params1.init(self.tokenizer, **{"img_gen_prefill": True}) + img_len = len(multimodal_params.images) image_guidance_scale = generation_params.image_guidance_scale @@ -466,11 +469,13 @@ async def generation_wrapper(prompt, sample, multimodal, request): prompt_condition, prompt_text_uncondition, prompt_img_uncondition = self.tokenizer.get_query_for_it2i( prompt ) + sample_params2 = SamplingParams() + sample_params2.init(self.tokenizer, **{"img_gen_prefill": True}) (con_gen, text_uncon_gen, img_uncon_gen) = await asyncio.gather( *[ generation_wrapper(prompt_condition, sample_params, multimodal_params, request), - generation_wrapper(prompt_text_uncondition, sample_params, multimodal_params.clone(), request), - generation_wrapper(prompt_img_uncondition, sample_params, MultimodalParams(), request), + generation_wrapper(prompt_text_uncondition, sample_params1, multimodal_params.clone(), request), + generation_wrapper(prompt_img_uncondition, sample_params2, MultimodalParams(), request), ] ) generation_params.update_it2i(con_gen, text_uncon_gen, img_uncon_gen) @@ -481,7 +486,7 @@ async def generation_wrapper(prompt, sample, multimodal, request): (con_gen, uncon_gen) = await asyncio.gather( *[ generation_wrapper(prompt_condition, sample_params, multimodal_params, request), - generation_wrapper(prompt_uncondition, sample_params, multimodal_params.clone(), request), + generation_wrapper(prompt_uncondition, sample_params1, multimodal_params.clone(), request), ] ) generation_params.update_t2i(con_gen, uncon_gen) diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py index 76900a92f2..b55d0d6337 100644 --- a/lightllm/server/x2i_server/lightx2v/adapter.py +++ b/lightllm/server/x2i_server/lightx2v/adapter.py @@ -1,3 +1,4 @@ +import datetime import inspect import torch import torch.distributed as dist @@ -38,7 +39,7 @@ def __init__(self, args: StartArgs, rank: int, world_size: int): torch.cuda.set_device(rank) self._init_pipeline() - self.task_dist_group = dist.new_group(backend="gloo") + self.task_dist_group = dist.new_group(backend="gloo", timeout=datetime.timedelta(days=30)) def _init_pipeline(self): os.environ["MASTER_ADDR"] = "127.0.0.1" diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index 28462da8f1..aef9b0ab88 100644 --- a/lightllm/server/x2i_server/manager.py +++ b/lightllm/server/x2i_server/manager.py @@ -218,8 +218,7 @@ def setup_devices(args: StartArgs): else: devices = [int(x.strip()) for x in devices.split(",") if x.strip()] - llm_need_gpus = args.tp * args.dp - # llm_need_gpus = 0 + llm_need_gpus = 0 if args.x2i_server_deploy_mode == "colocate" else args.tp * args.dp x2i_need_gpus = args.x2i_server_used_gpus if len(devices) < llm_need_gpus + x2i_need_gpus: raise ValueError(f"devices {devices} not enough, need {llm_need_gpus} and {x2i_need_gpus}") From 2e7c525a8d9fb454d34a7f6af209bd69b73a5869 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 13 Apr 2026 12:17:49 +0000 Subject: [PATCH 22/41] enable thinking & input_image num --- lightllm/models/neo_chat_moe/model.py | 6 ++++-- lightllm/server/api_models.py | 2 +- lightllm/server/api_openai.py | 4 ++-- lightllm/server/httpserver/manager.py | 9 +++++++-- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git 
a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py index 648d696877..dba8965512 100644 --- a/lightllm/models/neo_chat_moe/model.py +++ b/lightllm/models/neo_chat_moe/model.py @@ -154,11 +154,13 @@ def get_query_for_it2i(self, prompt: str): question_img_uncondition = self._build_t2i_query("") return query_condition, query_text_uncondition, question_img_uncondition - def get_query_for_t2i(self, prompt: str): + def get_query_for_t2i(self, prompt: str, input_image_num: int = 0): # prompt is already applied image_len = prompt.count(IMG_TOKEN) query_condition = prompt + IMG_START_TOKEN if not prompt.endswith(IMG_START_TOKEN) else prompt - query_uncondition = self._build_t2i_query("", thinking_content=IMG_TOKEN * image_len) + query_uncondition = self._build_t2i_query( + IMG_TOKEN * input_image_num, thinking_content=IMG_TOKEN * (image_len - input_image_num) + ) return query_condition, query_uncondition diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index 664a248e80..c2fce7dd95 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -418,7 +418,7 @@ class ImageConfig(BaseModel): "4:5": (896, 1152), "5:4": (1152, 896), "9:16": (768, 1344), - "16:9": (1344, 768), + "16:9": (1920, 1080), "21:9": (1536, 672), } diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 52fe59aa39..dbbfbb5528 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -668,7 +668,6 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request if request.chat_template_kwargs is None: request.chat_template_kwargs = {} - request.chat_template_kwargs.update({"enable_thinking": False}) chat_request: ChatCompletionRequest = ChatCompletionRequest(**request.model_dump()) @@ -689,6 +688,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request created_time = int(time.time()) prompt, sampling_params, multimodal_params = await _get_text_generator_input(chat_request) + input_image_num = len(multimodal_params.images) x2i_params = X2IParams() x2i_params.init_from_image_config(request.image_config) @@ -732,7 +732,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request if need_call_x2i: prompt += output_chunk images = await g_objs.httpserver_manager.generate_image( - prompt, x2i_params, multimodal_params.clone(), request=raw_request + prompt, x2i_params, multimodal_params.clone(), request=raw_request, input_image_num=input_image_num ) if len(images) == 0: logger.warning(f"No image generated by x2i: {prompt[-100:]}, exit...") diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 0304ecea91..31d445483b 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -435,7 +435,12 @@ async def generate( return async def generate_image( - self, prompt: str, generation_params: X2IParams, multimodal_params: MultimodalParams, request: Request + self, + prompt: str, + generation_params: X2IParams, + multimodal_params: MultimodalParams, + request: Request, + input_image_num: int = 0, ): generate_req_ids = [] @@ -481,7 +486,7 @@ async def generation_wrapper(prompt, sample, multimodal, request): generation_params.update_it2i(con_gen, text_uncon_gen, img_uncon_gen) else: # call t2i - prompt_condition, prompt_uncondition = self.tokenizer.get_query_for_t2i(prompt) + prompt_condition, prompt_uncondition = self.tokenizer.get_query_for_t2i(prompt, input_image_num) # 
logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}") (con_gen, uncon_gen) = await asyncio.gather( *[ From c9d19d011e883aff061dc3f49d22b2c6908083d5 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 13 Apr 2026 13:14:52 +0000 Subject: [PATCH 23/41] keep the same resolution for it2i --- lightllm/server/api_openai.py | 3 ++- lightllm/server/core/objs/x2i_params.py | 12 ++++++++++++ lightllm/server/httpserver/manager.py | 5 +++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index dbbfbb5528..55112d7a4b 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -630,8 +630,9 @@ async def _chat_completion_image_only( from .api_http import g_objs created_time = int(time.time()) + input_image_num = len(multimodal_params.images) images = await g_objs.httpserver_manager.generate_image( - prompt, x2i_params, multimodal_params.clone(), request=raw_request + prompt, x2i_params, multimodal_params.clone(), request=raw_request, input_image_num=input_image_num ) response_images = _message_contents_from_raw_images(images, request.image_config.image_type) chat_message = ChatMessage( diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py index e0f1bdb30c..281e5f66d2 100644 --- a/lightllm/server/core/objs/x2i_params.py +++ b/lightllm/server/core/objs/x2i_params.py @@ -78,6 +78,7 @@ def _get(key, default): self.past_kvcache_img = PastKVCachePageList() self.total_prompt_tokens = 0 self.request_id = 0 + self.has_updated_hw = False def init_from_image_config(self, image_config: Any) -> None: """从 HTTP `image_config`(api_models.ImageConfig)填充,与 `init(**kwargs)` 共用默认值逻辑。""" @@ -104,6 +105,17 @@ def init_from_image_config(self, image_config: Any) -> None: break self.init(**kwargs) + def update_hw(self, width: int, height: int): + if self.has_updated_hw: + return + from lightllm.models.neo_chat_moe.vision_process import smart_resize + + h, w = smart_resize(height, width, factor=32, min_pixels=512 * 512, max_pixels=2048 * 2048) + self.width = w + self.height = h + self.has_updated_hw = True + return + def update(self, past_kv: PastKVCachePageList, meta: Dict): item: PastKVCacheItem = meta.get("kv_cache_item") past_kv.token_len = item.token_len diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 31d445483b..99ed011cb1 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -495,6 +495,11 @@ async def generation_wrapper(prompt, sample, multimodal, request): ] ) generation_params.update_t2i(con_gen, uncon_gen) + if input_image_num > 0: + # for it2i, the output image size is the same as the input image size + generation_params.update_hw( + multimodal_params.images[0].image_w, multimodal_params.images[0].image_h + ) # use the first request id as the gen image request id x2i_req_id = generate_req_ids[0] generation_params.request_id = x2i_req_id From d0967b70dbc6ef9ba9a5adb6fe1656ad3e52473c Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Tue, 14 Apr 2026 10:58:57 +0000 Subject: [PATCH 24/41] fixup. 
--- .../triton_kernel/kv_cache_offload.py | 2 +- .../server/router/model_infer/infer_batch.py | 4 ++ .../model_infer/mode_backend/past_kv_cache.py | 39 +++++++++++++++---- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py index 686c0c1790..01d32e908c 100644 --- a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py +++ b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py @@ -916,7 +916,7 @@ def offload_gpu_kv_to_cpu_for_x2i( assert gpu_heads == 1 assert cpu_heads == 1 - need_offload == tp_index == 0 + need_offload = tp_index == 0 cpu_k_start_head_index = 0 cpu_k_head_num = 1 gpu_k_start_head_index = 0 diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 0088c60d65..898a48f87a 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -590,6 +590,10 @@ def handle( # detokenization 进程需要的信息,注意这些变量的写入顺序避免异步协同问题。 shm_req.shm_cur_output_len = self.output_len + if shm_req.sample_params.img_gen_prefill: + # img gen prefill 需要等待 kv cache 卸载到 cpu 后才更新detokenization需要的信息 + return + if finish_status.is_finished(): shm_req.finish_token_index = shm_req.input_len + self.output_len - 1 shm_req.finish_status = req_obj.finish_status diff --git a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py index 9035de219b..9d0a8d4dd4 100644 --- a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py @@ -20,17 +20,19 @@ class TransTask: req_obj: InferReq sync_event: torch.cuda.Event - + buffer_index: int class PastKVCacheModule(object): def __init__(self, backend): from .base_backend import ModeBackend self.backend: ModeBackend = backend self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=False) - self.page_index_buffer = torch.empty((LIGHTLLM_TOKEN_HASH_LIST_SIZE * 2,), dtype=torch.int32, device="cuda") + self.page_index_buffer = torch.empty((1024 * LIGHTLLM_TOKEN_HASH_LIST_SIZE,), dtype=torch.int32, device="cuda") + self.page_index_buffer_free_index = list(range(1024)) self.past_kv_cache_task: Deque[TransTask] = deque() self.sync_task_status_group = create_new_group_for_current_dp("gloo") + @lru_cache() def need_sync_compute_stream(self) -> bool: """ @@ -51,7 +53,6 @@ def need_sync_compute_stream(self) -> bool: logger.info("PastKVCacheModule: no need sync compute stream.") return False - def offload_finished_reqs_to_past_kv_cache(self, finished_reqs: List[InferReq]) -> List[InferReq]: """ Offload the finished reqs to past kv cache, and return the truly finished reqs that can be freed in infer batch. 
@@ -70,28 +71,37 @@ def offload_finished_reqs_to_past_kv_cache(self, finished_reqs: List[InferReq]) if req.past_kv_cache_task_status.is_running(): continue - assert req.past_kv_cache_task_status.is_not_started() + assert req.past_kv_cache_task_status.is_not_started(), \ + f"req {req.req_id} has invalid past kv cache task status {req.past_kv_cache_task_status}" if self.need_sync_compute_stream(): g_infer_context.get_overlap_stream().synchronize() trans_task = self._start_kv_cache_offload(req=req) - assert trans_task is not None + assert trans_task is not None, f"req {req.req_id} start kv cache offload failed" self.past_kv_cache_task.append(trans_task) - return true_finished_reqs def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]: with torch.cuda.stream(g_infer_context.get_cpu_kv_cache_stream()): + if len(self.page_index_buffer_free_index) == 0: + raise RuntimeError("No free page index for offloading past kv cache to CPU.") + + assert req.shm_req.past_kv_cache_page_indexes.size <= LIGHTLLM_TOKEN_HASH_LIST_SIZE + + free_index = self.page_index_buffer_free_index.pop(0) + start = free_index * LIGHTLLM_TOKEN_HASH_LIST_SIZE + end = start + req.shm_req.past_kv_cache_page_indexes.size + page_indexes = torch.tensor(req.shm_req.past_kv_cache_page_indexes.get_all(), dtype=torch.int32, device='cpu', pin_memory=True) num_tokens = req.shm_req.input_len assert req.cur_kv_len >= num_tokens assert num_tokens <= len(page_indexes) * self.past_kv_cache_client.token_page_size - cuda_page_indexes = self.page_index_buffer[:len(page_indexes)] + cuda_page_indexes = self.page_index_buffer[start:end] cuda_page_indexes.copy_(page_indexes) token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0: num_tokens] @@ -102,7 +112,7 @@ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]: cpu_cache_meta = self.past_kv_cache_client.kv_cache_tensor_meta cpu_kv_cache = self.past_kv_cache_client.cpu_kv_cache_tensor[:, :, :, :, 0:cpu_cache_meta.head_dim] cpu_kv_cache_scale = self.past_kv_cache_client.cpu_kv_cache_tensor[ - :, :, :, :, cpu_cache_meta.head_dim + :, :, :, :, cpu_cache_meta.head_dim : ].view(mem_manager.scale_buffer.dtype) gpu_kv_cache_scale = mem_manager.scale_buffer else: @@ -124,10 +134,12 @@ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]: ) sync_event = torch.cuda.Event() sync_event.record() + # sync_event.synchronize() req.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING return TransTask( req_obj=req, sync_event=sync_event, + buffer_index=free_index ) def update_past_kv_cache_task_states(self): @@ -148,3 +160,14 @@ def update_past_kv_cache_task_states(self): self.past_kv_cache_task.extendleft(reversed(unfinished)) for task in finished: task.req_obj.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED + self.page_index_buffer_free_index.append(task.buffer_index) + + if self.backend.is_master_in_dp: + shm_req = task.req_obj.shm_req + assert task.req_obj.finish_status.is_finished() + shm_req.finish_token_index = shm_req.input_len + shm_req.shm_cur_output_len - 1 + shm_req.finish_status = task.req_obj.finish_status + shm_req.candetoken_out_len = shm_req.shm_cur_output_len + else: + if len(trans_ok_tasks) > 0: + self.past_kv_cache_task.extendleft(reversed(trans_ok_tasks)) \ No newline at end of file From f6ae0d1974a7194e607adc8ec6eac9291f8d1460 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Wed, 15 Apr 2026 06:34:36 +0000 Subject: [PATCH 25/41] workaround for illegal memory access. 
--- .../server/router/model_infer/mode_backend/past_kv_cache.py | 1 + lightllm/server/x2i_server/lightx2v/adapter.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py index 9d0a8d4dd4..e6457ded9a 100644 --- a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py @@ -134,6 +134,7 @@ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]: ) sync_event = torch.cuda.Event() sync_event.record() + sync_event.wait(g_infer_context.get_overlap_stream()) # sync_event.synchronize() req.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING return TransTask( diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py index b55d0d6337..b738b00d7d 100644 --- a/lightllm/server/x2i_server/lightx2v/adapter.py +++ b/lightllm/server/x2i_server/lightx2v/adapter.py @@ -109,6 +109,8 @@ async def _process(self, param: X2IParams): # release await self.result_socket.send_pyobj(X2ICacheRelease(request_id=param.request_id)) + logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with: {param}") + if is_t2i: self.pipe.runner.set_kvcache( past_kv_cache, From 20439fefdeb56159816e96fe30741266253b6e1f Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 20 Apr 2026 08:05:57 +0000 Subject: [PATCH 26/41] fix attention --- .../context_attention_fwd_neo.py | 2 +- .../test_context_attention_fwd_neo.py | 345 ++++++++++++++++++ 2 files changed, 346 insertions(+), 1 deletion(-) create mode 100644 unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py index 74ff82cae4..a745b7c9da 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py @@ -103,7 +103,7 @@ def _fwd_kernel( qk += tl.dot(q, k) # mask: causal OR same gid (only possible inside NEW part) - mask = (q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None] + mask = ((q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None]) & k_valid[None, :] qk = tl.where(mask, qk * sm_scale, -1.0e8) # online softmax diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py new file mode 100644 index 0000000000..4abfde5a83 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py @@ -0,0 +1,345 @@ +"""Unit test for ``context_attention_fwd_neo``. + +Torch reference expresses the *semantics* of the attention, not the kernel's +internal block structure — it has no notion of BLOCK_N / BLOCK_M. For each +batch element we gather K/V for the whole request (prompt + new tokens) via +``req_to_token_indexs`` and apply:: + + allow[m, k] = (k <= q_pos[m]) OR image_tag[m] for k in [0, total) + +i.e. normal queries are causal, image-token queries can see every real key in +the request. If the Triton kernel disagrees with this reference, the kernel is +wrong. 
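+
+As a toy illustration (not taken from the kernel): a request with prompt=1
+cached token and new=3 query tokens, where only the middle new token is an
+image token, has q_pos = [1, 2, 3], k_pos = [0, 1, 2, 3] and::
+
+    allow = [[1, 1, 0, 0],   # text  token at pos 1: causal only
+             [1, 1, 1, 1],   # image token at pos 2: sees every real key
+             [1, 1, 1, 1]]   # text  token at pos 3: causal already reaches all keys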
+ +Run directly for quick debugging: + + python unit_tests/common/basemodel/triton_kernel/att/prefill_att/\ + test_context_attention_fwd_neo.py + +or via pytest: + + pytest unit_tests/common/basemodel/triton_kernel/att/prefill_att/\ + test_context_attention_fwd_neo.py -x -s +""" + +import math +import pytest +import torch + +from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import ( + context_attention_fwd_neo, +) + + +def torch_reference_context_attention_neo( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + b_req_idx: torch.Tensor, + b_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, + b_prompt_cache_len: torch.Tensor, + req_to_token_indexs: torch.Tensor, + b_image_token_tag: torch.Tensor, +) -> torch.Tensor: + device = q.device + dtype = q.dtype + _, Hq, D = q.shape + Hk = k.shape[1] + kv_group = Hq // Hk + scale = 1.0 / math.sqrt(D) + + out = torch.empty_like(q) + + for b in range(b_seq_len.shape[0]): + req_idx = int(b_req_idx[b].item()) + total = int(b_seq_len[b].item()) + prompt = int(b_prompt_cache_len[b].item()) + new = total - prompt + if new <= 0: + continue + + q_start = int(b_start_loc[b].item()) + q_blk = q[q_start : q_start + new] # [M, Hq, D] + image_tag = b_image_token_tag[q_start : q_start + new].to(torch.bool) + + token_locs = req_to_token_indexs[req_idx, :total].to(torch.int64) + k_blk = k[token_locs] # [total, Hk, D] + v_blk = v[token_locs] + + q_pos = torch.arange(prompt, total, device=device, dtype=torch.int64) # [M] + k_pos = torch.arange(0, total, device=device, dtype=torch.int64) # [total] + causal = k_pos[None, :] <= q_pos[:, None] + allow = causal | image_tag[:, None] + + out_blk = torch.empty_like(q_blk) + for h in range(Hq): + h_k = h // kv_group + q_h = q_blk[:, h, :].to(torch.float32) + k_h = k_blk[:, h_k, :].to(torch.float32) + v_h = v_blk[:, h_k, :].to(torch.float32) + + scores = (q_h @ k_h.transpose(0, 1)) * scale + scores = torch.where(allow, scores, torch.full_like(scores, -1.0e8)) + probs = torch.softmax(scores, dim=-1) + out_h = (probs @ v_h).to(dtype) + out_blk[:, h, :] = out_h + + out[q_start : q_start + new] = out_blk + + return out + + +def _build_inputs( + batch: int, + Hq: int, + Hk: int, + D: int, + dtype: torch.dtype, + device: str, + max_new: int = 256, + max_prompt: int = 512, + image_prob: float = 0.7, + num_image_spans_max: int = 3, + image_span_len_max: int = 24, + kv_pool_slack: int = 4096, + seed: int = 0, +): + g = torch.Generator(device="cpu").manual_seed(seed) + + new_lens = torch.randint(low=1, high=max_new + 1, size=(batch,), generator=g) + prompt_lens = torch.randint(low=0, high=max_prompt + 1, size=(batch,), generator=g) + total_lens = new_lens + prompt_lens + + sum_new = int(new_lens.sum().item()) + sum_total = int(total_lens.sum().item()) + max_total_len = int(total_lens.max().item()) + max_new_len = int(new_lens.max().item()) + + b_start_loc = torch.zeros(batch, dtype=torch.int32) + cur = 0 + for i in range(batch): + b_start_loc[i] = cur + cur += int(new_lens[i].item()) + + # Permute so batch idx != request idx: exercises the Req_to_tokens indexing. + b_req_idx = torch.randperm(batch, generator=g).to(torch.int32) + + # Global KV pool with scattered, non-contiguous slot assignment per request. 
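+    # (e.g. a 3-token request might land in slots [2977, 1042, 1531]; only
+    # req_to_token_indexs gives these slots meaning, adjacency does not.)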
+ base = 1024 + kv_pool_size = base + sum_total + kv_pool_slack + pool = torch.randperm(kv_pool_size - base, generator=g)[:sum_total] + base + + req_to_token_indexs = torch.zeros((batch, max_total_len), dtype=torch.int32) + p = 0 + for r_logical, req_id in enumerate(b_req_idx.tolist()): + L = int(total_lens[r_logical].item()) + req_to_token_indexs[req_id, :L] = pool[p : p + L].to(torch.int32) + p += L + + # Randomly place contiguous image-token spans inside each batch's NEW region. + b_image_token_tag = torch.zeros(sum_new, dtype=torch.bool) + for i in range(batch): + M = int(new_lens[i].item()) + if M < 2: + continue + if torch.rand((), generator=g).item() > image_prob: + continue + n_spans = int(torch.randint(1, num_image_spans_max + 1, (1,), generator=g).item()) + start_pack = int(b_start_loc[i].item()) + for _ in range(n_spans): + span_len = int( + torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item() + ) + span_len = min(span_len, M) + s_rel = int(torch.randint(0, M - span_len + 1, (1,), generator=g).item()) + b_image_token_tag[start_pack + s_rel : start_pack + s_rel + span_len] = True + + b_seq_len = total_lens.to(torch.int32) + b_prompt_cache_len = prompt_lens.to(torch.int32) + + # position_ids[0]: kernel API still requires it even though its current + # mask logic only reads b_image_token_tag. + position_ids_0 = torch.empty(sum_new, dtype=torch.int32) + for i in range(batch): + M = int(new_lens[i].item()) + P = int(prompt_lens[i].item()) + s = int(b_start_loc[i].item()) + position_ids_0[s : s + M] = torch.arange(P, P + M, dtype=torch.int32) + + q = torch.randn((sum_new, Hq, D), dtype=dtype, device=device) + k = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device) + v = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device) + o = torch.empty_like(q) + + return dict( + q=q, + k=k, + v=v, + o=o, + position_ids_0=position_ids_0.to(device), + b_req_idx=b_req_idx.to(device), + b_start_loc=b_start_loc.to(device), + b_seq_len=b_seq_len.to(device), + b_prompt_cache_len=b_prompt_cache_len.to(device), + max_new_len=max_new_len, + req_to_token_indexs=req_to_token_indexs.to(device), + b_image_token_tag=b_image_token_tag.to(device), + new_lens=new_lens, + prompt_lens=prompt_lens, + ) + + +def _report_per_batch_error(out_triton, out_ref, new_lens, b_start_loc, image_tag, tag=""): + print(f"\n[{tag}] per-batch error breakdown (abs / rel / cos):") + for i in range(new_lens.shape[0]): + s = int(b_start_loc[i].item()) + m = int(new_lens[i].item()) + if m == 0: + continue + a = out_triton[s : s + m].float() + b = out_ref[s : s + m].float() + abs_err = (a - b).abs().max().item() + denom = b.abs().max().item() + 1e-6 + rel_err = abs_err / denom + cos = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() + n_img = int(image_tag[s : s + m].sum().item()) + print( + f" batch {i:02d} | M={m:4d} | image_tokens={n_img:4d} | " + f"max_abs={abs_err:.4e} | max_rel={rel_err:.4e} | cos={cos:.6f}" + ) + + +def _run_case( + batch: int, + Hq: int, + Hk: int, + D: int, + dtype: torch.dtype, + seed: int, + max_new: int, + max_prompt: int, + atol: float = 5e-2, + rtol: float = 5e-2, + cos_threshold: float = 0.99, + verbose: bool = True, +): + assert Hq % Hk == 0 + device = "cuda" + + inputs = _build_inputs( + batch=batch, + Hq=Hq, + Hk=Hk, + D=D, + dtype=dtype, + device=device, + max_new=max_new, + max_prompt=max_prompt, + seed=seed, + ) + + context_attention_fwd_neo( + inputs["q"], + inputs["k"], + inputs["v"], + inputs["o"], + inputs["position_ids_0"], + 
inputs["b_req_idx"], + inputs["b_start_loc"], + inputs["b_seq_len"], + inputs["b_prompt_cache_len"], + inputs["max_new_len"], + inputs["req_to_token_indexs"], + inputs["b_image_token_tag"], + ) + out_triton = inputs["o"] + + out_ref = torch_reference_context_attention_neo( + inputs["q"], + inputs["k"], + inputs["v"], + inputs["b_req_idx"], + inputs["b_start_loc"], + inputs["b_seq_len"], + inputs["b_prompt_cache_len"], + inputs["req_to_token_indexs"], + inputs["b_image_token_tag"], + ) + + a = out_triton.float() + b = out_ref.float() + abs_err = (a - b).abs().max().item() + denom = b.abs().max().item() + 1e-6 + rel_err = abs_err / denom + cos = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() + + n_image = int(inputs["b_image_token_tag"].sum().item()) + n_tokens = int(inputs["b_image_token_tag"].numel()) + if verbose: + print( + f"\ncase: batch={batch} Hq={Hq} Hk={Hk} D={D} dtype={dtype} " + f"seed={seed} image_tokens={n_image}/{n_tokens}" + ) + print( + f" global: max_abs={abs_err:.4e} max_rel={rel_err:.4e} cos={cos:.6f} " + f"(allclose atol={atol}, rtol={rtol}? " + f"{torch.allclose(a, b, atol=atol, rtol=rtol)})" + ) + _report_per_batch_error( + out_triton, + out_ref, + inputs["new_lens"], + inputs["b_start_loc"], + inputs["b_image_token_tag"], + tag=f"seed={seed}", + ) + + return abs_err, rel_err, cos + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA") +@pytest.mark.parametrize( + "batch,Hq,Hk,D,dtype,seed,max_new,max_prompt", + [ + (4, 8, 2, 128, torch.bfloat16, 0, 128, 256), + (4, 8, 2, 128, torch.bfloat16, 1, 256, 512), + (8, 16, 4, 128, torch.bfloat16, 2, 256, 512), + (16, 28, 4, 128, torch.bfloat16, 3, 128, 256), + (4, 8, 2, 128, torch.float16, 4, 256, 512), + (4, 8, 8, 64, torch.bfloat16, 5, 128, 256), + (3, 8, 2, 128, torch.bfloat16, 6, 8, 1024), + ], +) +def test_context_attention_fwd_neo(batch, Hq, Hk, D, dtype, seed, max_new, max_prompt): + abs_err, rel_err, cos = _run_case( + batch=batch, + Hq=Hq, + Hk=Hk, + D=D, + dtype=dtype, + seed=seed, + max_new=max_new, + max_prompt=max_prompt, + verbose=True, + ) + assert cos > 0.99, f"cosine similarity too low: {cos}" + assert rel_err < 5e-2, f"max relative error too large: {rel_err}" + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("No CUDA available.") + raise SystemExit(0) + + torch.manual_seed(0) + + cases = [ + dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.bfloat16, seed=0, max_new=128, max_prompt=256), + dict(batch=8, Hq=16, Hk=4, D=128, dtype=torch.bfloat16, seed=1, max_new=256, max_prompt=512), + dict(batch=16, Hq=28, Hk=4, D=128, dtype=torch.bfloat16, seed=2, max_new=128, max_prompt=256), + dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.float16, seed=3, max_new=256, max_prompt=512), + ] + + for cfg in cases: + _run_case(**cfg, verbose=True) From 35906f5b2422e0bc3f3600c9fdd6788020c25330 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 20 Apr 2026 14:02:58 +0000 Subject: [PATCH 27/41] pass unit test --- .../att/prefill_att/test_fa3_neo.py | 548 ++++++++++++++++++ 1 file changed, 548 insertions(+) create mode 100644 unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py new file mode 100644 index 0000000000..79b1e8e343 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py @@ -0,0 +1,548 @@ +"""Unit test for the 
FA3-based prefill path with image-token support. + +This test pre-wires a call to ``flash_attn_with_kvcache`` with an +``image_token_tag`` keyword argument. The expectation is that ``fa3-neo``'s +``flash_attn_with_kvcache`` will be extended with an optional +``image_token_tag`` parameter that, for queries flagged as image tokens, +relaxes the causal mask so they can attend bidirectionally to every real key +in the request. + +Torch reference expresses the *semantics* of the attention, not FA3's internal +tiling — it has no notion of BLOCK_N / BLOCK_M. For each batch element we +gather K/V for the whole request (prompt + new tokens) and apply:: + + allow[m, k] = (k <= q_pos[m]) OR image_tag[m] for k in [0, total) + +i.e. normal queries are causal, image-token queries can see every real key in +the request. If FA3 disagrees with this reference, the kernel is wrong. + +Run directly for quick debugging: + + python unit_tests/common/basemodel/triton_kernel/att/prefill_att/\ + test_fa3_neo.py + +or via pytest: + + pytest unit_tests/common/basemodel/triton_kernel/att/prefill_att/\ + test_fa3_neo.py -x -s +""" + +import math +import pytest +import torch + +from flash_attn_interface import flash_attn_with_kvcache + +try: + import triton + import triton.testing as triton_testing +except ImportError: + triton = None + triton_testing = None + +try: + from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import ( + context_attention_fwd_neo, + ) +except ImportError: + context_attention_fwd_neo = None + + +def torch_reference_context_attention_neo( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + b_req_idx: torch.Tensor, + b_q_start_loc: torch.Tensor, + b_seq_len: torch.Tensor, + b_prompt_cache_len: torch.Tensor, + req_to_token_indexs: torch.Tensor, + b_image_token_tag: torch.Tensor, +) -> torch.Tensor: + device = q.device + dtype = q.dtype + _, Hq, D = q.shape + Hk = k.shape[1] + kv_group = Hq // Hk + scale = 1.0 / math.sqrt(D) + + out = torch.empty_like(q) + + for b in range(b_seq_len.shape[0]): + req_idx = int(b_req_idx[b].item()) + seq_len = int(b_seq_len[b].item()) + prompt_cache_len = int(b_prompt_cache_len[b].item()) + q_seq_len = seq_len - prompt_cache_len + if q_seq_len <= 0: + continue + + q_start = int(b_q_start_loc[b].item()) + q_blk = q[q_start : q_start + q_seq_len] # [M, Hq, D] + image_tag = b_image_token_tag[q_start : q_start + q_seq_len].to(torch.bool) + + token_locs = req_to_token_indexs[req_idx, :seq_len].to(torch.int64) + k_blk = k[token_locs] # [seq_len, Hk, D] + v_blk = v[token_locs] + + q_pos = torch.arange(prompt_cache_len, seq_len, device=device, dtype=torch.int64) # [M] + k_pos = torch.arange(0, seq_len, device=device, dtype=torch.int64) # [seq_len] + causal = k_pos[None, :] <= q_pos[:, None] + allow = causal | image_tag[:, None] + + out_blk = torch.empty_like(q_blk) + for h in range(Hq): + h_k = h // kv_group + q_h = q_blk[:, h, :].to(torch.float32) + k_h = k_blk[:, h_k, :].to(torch.float32) + v_h = v_blk[:, h_k, :].to(torch.float32) + + scores = (q_h @ k_h.transpose(0, 1)) * scale + scores = torch.where(allow, scores, torch.full_like(scores, -1.0e8)) + probs = torch.softmax(scores, dim=-1) + out_h = (probs @ v_h).to(dtype) + out_blk[:, h, :] = out_h + + out[q_start : q_start + q_seq_len] = out_blk + + return out + + +def _build_inputs( + batch: int, + Hq: int, + Hk: int, + D: int, + dtype: torch.dtype, + device: str, + max_q_seq_len: int = 256, + max_prompt_cache_len: int = 512, + image_prob: float = 0.7, + num_image_spans_max: int = 3, + 
image_span_len_max: int = 24, + kv_pool_slack: int = 4096, + seed: int = 0, +): + """Build one realistic prefill batch. + + Naming matches lightllm's infer_state: + - ``q_seq_len`` = number of new Q tokens in this prefill call + - ``prompt_cache_len`` = length of the already-cached prefix for this req + - ``seq_len`` = prompt_cache_len + q_seq_len (total KV length) + """ + g = torch.Generator(device="cpu").manual_seed(seed) + + q_seq_lens = torch.randint(low=1, high=max_q_seq_len + 1, size=(batch,), generator=g) + prompt_cache_lens = torch.randint(low=0, high=max_prompt_cache_len + 1, size=(batch,), generator=g) + seq_lens = q_seq_lens + prompt_cache_lens + + sum_q = int(q_seq_lens.sum().item()) + sum_total = int(seq_lens.sum().item()) + max_seq_len_in_batch = int(seq_lens.max().item()) + max_q_seq_len_in_batch = int(q_seq_lens.max().item()) + + b_q_start_loc = torch.zeros(batch, dtype=torch.int32) + cur = 0 + for i in range(batch): + b_q_start_loc[i] = cur + cur += int(q_seq_lens[i].item()) + + # Permute so batch idx != request idx: exercises the page_table indexing. + b_req_idx = torch.randperm(batch, generator=g).to(torch.int32) + + # Global KV pool with scattered, non-contiguous slot assignment per request. + base = 1024 + kv_pool_size = base + sum_total + kv_pool_slack + pool = torch.randperm(kv_pool_size - base, generator=g)[:sum_total] + base + + req_to_token_indexs = torch.zeros((batch, max_seq_len_in_batch), dtype=torch.int32) + p = 0 + for r_logical, req_id in enumerate(b_req_idx.tolist()): + L = int(seq_lens[r_logical].item()) + req_to_token_indexs[req_id, :L] = pool[p : p + L].to(torch.int32) + p += L + + # Randomly place contiguous image-token spans inside each batch's new-Q region. + b_image_token_tag = torch.zeros(sum_q, dtype=torch.bool) + for i in range(batch): + M = int(q_seq_lens[i].item()) + if M < 2: + continue + if torch.rand((), generator=g).item() > image_prob: + continue + n_spans = int(torch.randint(1, num_image_spans_max + 1, (1,), generator=g).item()) + start_pack = int(b_q_start_loc[i].item()) + for _ in range(n_spans): + span_len = int( + torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item() + ) + span_len = min(span_len, M) + s_rel = int(torch.randint(0, M - span_len + 1, (1,), generator=g).item()) + b_image_token_tag[start_pack + s_rel : start_pack + s_rel + span_len] = True + + b_seq_len = seq_lens.to(torch.int32) + b_prompt_cache_len = prompt_cache_lens.to(torch.int32) + + # Per-batch last image-token index in *batch-local* packed-q coordinates. + # Matches NeoChatInferStateInfo._compute_b_max_image_q_idx semantics: + # shape int32[batch]; value == -1 means that batch has no image tokens. + # Computed on CPU (b_image_token_tag is still on CPU here) so no D2H sync + # — keeps the eventual flash_attn_with_kvcache call CUDA-graph-safe. 
+ b_max_image_q_idx_cpu = torch.full((batch,), -1, dtype=torch.int32) + for b in range(batch): + start = int(b_q_start_loc[b].item()) + length = int(q_seq_lens[b].item()) + seg = b_image_token_tag[start : start + length] + idx = torch.nonzero(seg, as_tuple=False) + if idx.numel() > 0: + b_max_image_q_idx_cpu[b] = int(idx[-1, 0].item()) + + q = torch.randn((sum_q, Hq, D), dtype=dtype, device=device) + k = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device) + v = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device) + + return dict( + q=q, + k=k, + v=v, + b_req_idx=b_req_idx.to(device), + b_q_start_loc=b_q_start_loc.to(device), + b_seq_len=b_seq_len.to(device), + b_prompt_cache_len=b_prompt_cache_len.to(device), + max_seq_len_in_batch=max_seq_len_in_batch, + max_q_seq_len_in_batch=max_q_seq_len_in_batch, + req_to_token_indexs=req_to_token_indexs.to(device), + b_image_token_tag=b_image_token_tag.to(device), + b_max_image_q_idx=b_max_image_q_idx_cpu.to(device), + q_seq_lens=q_seq_lens, + prompt_cache_lens=prompt_cache_lens, + ) + + +def _fa3_prefill_with_image_tag(inputs: dict) -> torch.Tensor: + """Drive ``flash_attn_with_kvcache`` with the same prefill semantics as + ``Fa3PrefillAttState._nomarl_prefill_att`` plus an optional + ``image_token_tag`` kwarg for image-token bidirectional attention. + """ + q = inputs["q"] + k = inputs["k"] + v = inputs["v"] + device = q.device + + # Build page_table[b, p] = req_to_token_indexs[b_req_idx[b], p]. + page_table = inputs["req_to_token_indexs"][ + inputs["b_req_idx"].long(), : inputs["max_seq_len_in_batch"] + ].to(torch.int32) + + q_seq_lens_t = inputs["b_seq_len"].to(torch.int32) - inputs["b_prompt_cache_len"].to(torch.int32) + + cu_seqlens_q = torch.zeros(q_seq_lens_t.shape[0] + 1, dtype=torch.int32, device=device) + cu_seqlens_q[1:] = q_seq_lens_t.cumsum(0).to(torch.int32) + + cu_seqlens_k = torch.zeros(inputs["b_seq_len"].shape[0] + 1, dtype=torch.int32, device=device) + cu_seqlens_k[1:] = inputs["b_seq_len"].cumsum(0).to(torch.int32) + + sm_scale = 1.0 / math.sqrt(q.shape[-1]) + + # page_size = 1 paged KV cache view. + k_cache = k.view(k.shape[0], 1, k.shape[1], k.shape[2]) + v_cache = v.view(v.shape[0], 1, v.shape[1], v.shape[2]) + + o = flash_attn_with_kvcache( + q=q, + k_cache=k_cache, + v_cache=v_cache, + page_table=page_table, + cache_seqlens=inputs["b_seq_len"], + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k, + max_seqlen_q=inputs["max_q_seq_len_in_batch"], + softmax_scale=sm_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=None, + v_descale=None, + return_softmax_lse=False, + # image-token bidirectional attention. Packed like q (shape [sum_q], + # bool). Rows where the tag is True are allowed to attend to every + # real key in the request (not just the causal prefix). + # + # b_max_image_q_idx is int32[batch]: per-batch last image-token index + # in batch-local coordinates, -1 if no image in that batch. Lets the + # fa3 kernel skip n_block_max extension for text-only requests, which + # is critical for mixed-modality batches. Pre-computed in + # _build_inputs (host-side) so this call is CUDA-graph safe. 
+ image_token_tag=inputs["b_image_token_tag"], + max_image_q_idx=inputs["b_max_image_q_idx"], + ) + return o + + +def _report_per_batch_error(out_fa3, out_ref, q_seq_lens, b_q_start_loc, image_tag, tag=""): + print(f"\n[{tag}] per-batch error breakdown (abs / rel / cos):") + for i in range(q_seq_lens.shape[0]): + s = int(b_q_start_loc[i].item()) + m = int(q_seq_lens[i].item()) + if m == 0: + continue + a = out_fa3[s : s + m].float() + b = out_ref[s : s + m].float() + abs_err = (a - b).abs().max().item() + denom = b.abs().max().item() + 1e-6 + rel_err = abs_err / denom + cos = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() + n_img = int(image_tag[s : s + m].sum().item()) + print( + f" batch {i:02d} | M={m:4d} | image_tokens={n_img:4d} | " + f"max_abs={abs_err:.4e} | max_rel={rel_err:.4e} | cos={cos:.6f}" + ) + + +def _run_case( + batch: int, + Hq: int, + Hk: int, + D: int, + dtype: torch.dtype, + seed: int, + max_q_seq_len: int, + max_prompt_cache_len: int, + atol: float = 5e-2, + rtol: float = 5e-2, + cos_threshold: float = 0.99, + verbose: bool = True, +): + assert Hq % Hk == 0 + device = "cuda" + + inputs = _build_inputs( + batch=batch, + Hq=Hq, + Hk=Hk, + D=D, + dtype=dtype, + device=device, + max_q_seq_len=max_q_seq_len, + max_prompt_cache_len=max_prompt_cache_len, + seed=seed, + ) + + out_fa3 = _fa3_prefill_with_image_tag(inputs) + + out_ref = torch_reference_context_attention_neo( + inputs["q"], + inputs["k"], + inputs["v"], + inputs["b_req_idx"], + inputs["b_q_start_loc"], + inputs["b_seq_len"], + inputs["b_prompt_cache_len"], + inputs["req_to_token_indexs"], + inputs["b_image_token_tag"], + ) + + a = out_fa3.float().reshape_as(out_ref.float()) + b = out_ref.float() + abs_err = (a - b).abs().max().item() + denom = b.abs().max().item() + 1e-6 + rel_err = abs_err / denom + cos = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() + + n_image = int(inputs["b_image_token_tag"].sum().item()) + n_tokens = int(inputs["b_image_token_tag"].numel()) + if verbose: + print( + f"\ncase: batch={batch} Hq={Hq} Hk={Hk} D={D} dtype={dtype} " + f"seed={seed} image_tokens={n_image}/{n_tokens}" + ) + print( + f" global: max_abs={abs_err:.4e} max_rel={rel_err:.4e} cos={cos:.6f} " + f"(allclose atol={atol}, rtol={rtol}? 
" + f"{torch.allclose(a, b, atol=atol, rtol=rtol)})" + ) + _report_per_batch_error( + out_fa3, + out_ref, + inputs["q_seq_lens"], + inputs["b_q_start_loc"], + inputs["b_image_token_tag"], + tag=f"seed={seed}", + ) + + return abs_err, rel_err, cos + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA") +@pytest.mark.skipif(flash_attn_with_kvcache is None, reason="fa3 not available") +@pytest.mark.parametrize( + "batch,Hq,Hk,D,dtype,seed,max_q_seq_len,max_prompt_cache_len", + [ + (4, 8, 2, 128, torch.bfloat16, 0, 128, 256), + (4, 8, 2, 128, torch.bfloat16, 1, 256, 512), + (8, 16, 4, 128, torch.bfloat16, 2, 256, 512), + (16, 28, 4, 128, torch.bfloat16, 3, 128, 256), + (4, 8, 2, 128, torch.float16, 4, 256, 512), + (4, 8, 8, 64, torch.bfloat16, 5, 128, 256), + (3, 8, 2, 128, torch.bfloat16, 6, 8, 1024), + ], +) +def test_fa3_neo_prefill_with_image_tag( + batch, Hq, Hk, D, dtype, seed, max_q_seq_len, max_prompt_cache_len +): + abs_err, rel_err, cos = _run_case( + batch=batch, + Hq=Hq, + Hk=Hk, + D=D, + dtype=dtype, + seed=seed, + max_q_seq_len=max_q_seq_len, + max_prompt_cache_len=max_prompt_cache_len, + verbose=True, + ) + assert cos > 0.99, f"cosine similarity too low: {cos}" + assert rel_err < 5e-2, f"max relative error too large: {rel_err}" + + +def _bench_case( + batch: int, + Hq: int, + Hk: int, + D: int, + dtype: torch.dtype, + seed: int, + max_q_seq_len: int, + max_prompt_cache_len: int, + rep_ms: int = 100, + warmup_iters: int = 3, +): + """Compare FA3 (with image_token_tag) vs the original Triton + ``context_attention_fwd_neo`` using ``triton.testing.do_bench_cudagraph``. + + Both kernels are captured into a CUDA graph so scheduling/launch overhead + is minimized and the measurement reflects the kernel cost. + """ + assert Hq % Hk == 0 + device = "cuda" + + inputs = _build_inputs( + batch=batch, + Hq=Hq, + Hk=Hk, + D=D, + dtype=dtype, + device=device, + max_q_seq_len=max_q_seq_len, + max_prompt_cache_len=max_prompt_cache_len, + seed=seed, + ) + + # --- fa3 runner: output tensor is allocated inside flash_attn_with_kvcache. + def fa3_run(): + return _fa3_prefill_with_image_tag(inputs) + + # --- triton runner: pre-allocate o & position_ids so the graph captures + # only the kernel launch. + o_triton = torch.empty_like(inputs["q"]) + # Kernel signature requires position_ids but the current masking path does + # not read it; zeros are fine for perf measurement. + position_ids_0 = torch.zeros( + inputs["q"].shape[0], dtype=torch.int32, device=inputs["q"].device + ) + + def triton_run(): + context_attention_fwd_neo( + inputs["q"], + inputs["k"], + inputs["v"], + o_triton, + position_ids_0, + inputs["b_req_idx"], + inputs["b_q_start_loc"], + inputs["b_seq_len"], + inputs["b_prompt_cache_len"], + inputs["max_q_seq_len_in_batch"], + inputs["req_to_token_indexs"], + inputs["b_image_token_tag"], + ) + + # Warm up outside the graph capture so lazy allocations / autotune happen. 
+ for _ in range(warmup_iters): + fa3_run() + triton_run() + torch.cuda.synchronize() + + fa3_ms = triton_testing.do_bench_cudagraph(fa3_run, rep=rep_ms) + triton_ms = triton_testing.do_bench_cudagraph(triton_run, rep=rep_ms) + + n_image = int(inputs["b_image_token_tag"].sum().item()) + n_tokens = int(inputs["b_image_token_tag"].numel()) + sum_kv = int(inputs["b_seq_len"].sum().item()) + speedup = triton_ms / fa3_ms if fa3_ms > 0 else float("inf") + + print( + f"bench: batch={batch} Hq={Hq} Hk={Hk} D={D} dtype={str(dtype).split('.')[-1]:<8s} " + f"max_q_seq_len={max_q_seq_len:4d} max_prompt_cache_len={max_prompt_cache_len:4d} " + f"image_tokens={n_image:4d}/{n_tokens:5d} sum_kv={sum_kv:6d} | " + f"fa3 {fa3_ms*1000:8.1f} us | triton {triton_ms*1000:8.1f} us | " + f"speedup {speedup:5.2f}x" + ) + + return fa3_ms, triton_ms + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("No CUDA available.") + raise SystemExit(0) + if flash_attn_with_kvcache is None: + print("fa3 flash_attn_with_kvcache not available (sgl_kernel missing?).") + raise SystemExit(0) + + torch.manual_seed(0) + + cases = [ + dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.bfloat16, seed=0, max_q_seq_len=128, max_prompt_cache_len=256), + dict(batch=8, Hq=16, Hk=4, D=128, dtype=torch.bfloat16, seed=1, max_q_seq_len=256, max_prompt_cache_len=512), + dict(batch=16, Hq=28, Hk=4, D=128, dtype=torch.bfloat16, seed=2, max_q_seq_len=128, max_prompt_cache_len=256), + dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.float16, seed=3, max_q_seq_len=256, max_prompt_cache_len=512), + ] + + print("=" * 100) + print("Correctness") + print("=" * 100) + for cfg in cases: + _run_case(**cfg, verbose=True) + + if triton_testing is None or context_attention_fwd_neo is None: + print("\nSkipping benchmark: triton or context_attention_fwd_neo not available.") + raise SystemExit(0) + + print("\n" + "=" * 100) + print("Benchmark (triton.testing.do_bench_cudagraph)") + print("=" * 100) + + # Cold-prefill sweep: max_prompt_cache_len=0 so seq_len == q_seq_len. 
+ # Head shape matches neo_chat_moe / Qwen3 llm_config: + # num_attention_heads = 32, num_key_value_heads = 8, head_dim = 128 + # (GQA ratio 4:1) + bench_batches = [8, 16, 32, 64, 128] + bench_q_seq_lens = [1024, 4096, 8192] + + bench_cases = [ + dict( + batch=b, + Hq=32, + Hk=8, + D=128, + dtype=torch.bfloat16, + seed=0, + max_q_seq_len=s, + max_prompt_cache_len=0, + ) + for b in bench_batches + for s in bench_q_seq_lens + ] + + for cfg in bench_cases: + _bench_case(**cfg) From 1c8e7214598108ab3668556d6117b9b94c47a2e4 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 20 Apr 2026 14:18:46 +0000 Subject: [PATCH 28/41] add fa3_neo --- .../common/basemodel/attention/base_att.py | 5 ++ lightllm/common/basemodel/attention/fa3/fp.py | 10 ++++ .../layer_infer/transformer_layer_infer.py | 54 +++++++++++++------ lightllm/models/neo_chat/model.py | 13 +++++ lightllm/models/neo_chat_moe/infer_struct.py | 10 ++++ .../layer_infer/transformer_layer_infer.py | 52 ++++++++++++------ lightllm/models/neo_chat_moe/model.py | 13 +++++ .../triton_kernel/get_neo_position.py | 27 ++++++++++ 8 files changed, 152 insertions(+), 32 deletions(-) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 1286a46ec2..1ee38a1d03 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -70,6 +70,11 @@ class AttControl: nsa_prefill_dict: Dict = None nsa_decode: bool = False nsa_decode_dict: Dict = None + image_token_tag: Optional[torch.Tensor] = None + # max_image_q_idx: int32[batch]. 每个 batch 内最后一个 image token 的 local q + # 下标;没有 image token 的 batch 填 -1。fa3 kernel 据此决定该 batch 的哪些 + # M-block 需要把 n_block_max 延长到 full seqlen_k,避免无 image 的 batch 翻倍计算。 + max_image_q_idx: Optional[torch.Tensor] = None @dataclass diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 952bb39d91..3c99c4d095 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -92,6 +92,15 @@ def _nomarl_prefill_att( k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] sm_scale = 1.0 / (Lq ** 0.5) + + # Optional image-token bidirectional attention (neo_chat*). When None, + # the kernel path reduces to plain causal fa3 (zero overhead). 
+ extra_kwargs = {} + if att_control.image_token_tag is not None: + extra_kwargs["image_token_tag"] = att_control.image_token_tag + if att_control.max_image_q_idx is not None: + extra_kwargs["max_image_q_idx"] = att_control.max_image_q_idx + o = flash_attn_with_kvcache( q=q, k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), @@ -109,6 +118,7 @@ def _nomarl_prefill_att( v_descale=v_descale, return_softmax_lse=False, sinks=sink_weight, + **extra_kwargs, ) return o diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index 4517b5688a..0628411cbf 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -1,3 +1,4 @@ +import os import torch from functools import partial from typing import Tuple @@ -12,6 +13,8 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.common.basemodel.attention.base_att import AttControl +_USE_TRITON_PREFILL = os.environ.get("LIGHTLLM_NEO_PREFILL_TRITON_BACKEND", "0").strip().lower() in ("1", "true") + class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer): def __init__(self, data_type, network_config): @@ -83,23 +86,42 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC def _context_attention_kernel( self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd_neo( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] - infer_state.b_req_idx, - infer_state.b_q_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_q_seq_len, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_image_token_tag, + + + if _USE_TRITON_PREFILL: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + context_attention_fwd_neo( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] + infer_state.b_req_idx, + infer_state.b_q_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_q_seq_len, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_image_token_tag, + ) + return o_tensor + + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + + att_control = AttControl() + att_control.image_token_tag = getattr(infer_state, "b_image_token_tag", None) + att_control.max_image_q_idx = getattr(infer_state, "b_max_image_q_idx", None) + + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, + k=_k, + v=_v, + att_control=att_control, + alloc_func=self.alloc_tensor, ) - return o_tensor + return o_tensor.view(q.shape) def _token_attention_kernel( self, diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py index 14d9f96dc7..e17ad31fa1 
100644 --- a/lightllm/models/neo_chat/model.py +++ b/lightllm/models/neo_chat/model.py @@ -20,6 +20,11 @@ from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatPreAndPostLayerWeight from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo +from lightllm.common.basemodel.attention import ( + get_prefill_att_backend_class, + get_decode_att_backend_class, + BaseAttBackend, +) @ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3")) @@ -40,6 +45,14 @@ def __init__(self, kvargs): def _init_inferstate_cls(self): pass + def _init_att_backend(self): + self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class( + index=0, priority_list=["fa3"] + )(model=self) + self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class( + index=0, priority_list=["fa3"] + )(model=self) + def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: all_config = json.load(json_file) diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py index 1693bcb964..d7256af2b2 100644 --- a/lightllm/models/neo_chat_moe/infer_struct.py +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -20,9 +20,18 @@ def __init__(self): def init_some_extra_state(self, model: LlamaTpPartModel): LlamaInferStateInfo.init_some_extra_state(self, model) if self.is_prefill: + bsz = self.b_q_seq_len.shape[0] self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda( non_blocking=True ) + # Pre-allocate to -1 so the "no images anywhere" fast path in + # get_neo_position (which skips the triton kernel entirely) still + # yields a valid per-batch tensor. When the kernel runs, it + # overwrites every batch slot unconditionally with its computed + # value (-1 if no image lands in that batch's q window). 
+ self.b_max_image_q_idx = torch.full( + (bsz,), -1, dtype=torch.int32, device=self.b_q_seq_len.device + ) self.position_ids = self.get_neo_position(self.multimodal_params) else: b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] @@ -97,5 +106,6 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: b_q_seq_len=self.b_q_seq_len, b_start_loc=self.b_q_start_loc, b_image_token_tag=self.b_image_token_tag, + b_max_image_q_idx=self.b_max_image_q_idx, ) return position_ids diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 4c4d8a22ab..4fe1d2e9d1 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -1,3 +1,4 @@ +import os import torch from functools import partial from typing import Tuple @@ -12,6 +13,8 @@ from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.common.basemodel.attention.base_att import AttControl +_USE_TRITON_PREFILL = os.environ.get("LIGHTLLM_NEO_PREFILL_TRITON_BACKEND", "0").strip().lower() in ("1", "true") + class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): def __init__(self, data_type, network_config): @@ -82,23 +85,40 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC def _context_attention_kernel( self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None ) -> torch.Tensor: - o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - context_attention_fwd_neo( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - kv[:, 0 : self.tp_k_head_num_, :], - kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], - o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] - infer_state.b_req_idx, - infer_state.b_q_start_loc, - infer_state.b_seq_len, - infer_state.b_ready_cache_len, - infer_state.max_q_seq_len, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_image_token_tag, + if _USE_TRITON_PREFILL: + o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + context_attention_fwd_neo( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] + infer_state.b_req_idx, + infer_state.b_q_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_q_seq_len, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_image_token_tag, + ) + return o_tensor + + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_) + + att_control = AttControl() + att_control.image_token_tag = getattr(infer_state, "b_image_token_tag", None) + att_control.max_image_q_idx = getattr(infer_state, "b_max_image_q_idx", None) + + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, + k=_k, + v=_v, + att_control=att_control, + alloc_func=self.alloc_tensor, ) - return o_tensor + return o_tensor.view(q.shape) def _token_attention_kernel( self, diff --git a/lightllm/models/neo_chat_moe/model.py 
b/lightllm/models/neo_chat_moe/model.py index dba8965512..7ae3cd27c0 100644 --- a/lightllm/models/neo_chat_moe/model.py +++ b/lightllm/models/neo_chat_moe/model.py @@ -20,6 +20,11 @@ from lightllm.models.neo_chat_moe.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo +from lightllm.common.basemodel.attention import ( + get_prefill_att_backend_class, + get_decode_att_backend_class, + BaseAttBackend, +) IMG_START_TOKEN = "" IMG_END_TOKEN = "" @@ -182,6 +187,14 @@ def __init__(self, kvargs): def _init_inferstate_cls(self): pass + def _init_att_backend(self): + self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class( + index=0, priority_list=["fa3"] + )(model=self) + self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class( + index=0, priority_list=["fa3"] + )(model=self) + def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: all_config = json.load(json_file) diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py index 1a3d4af73b..223befad62 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py @@ -17,6 +17,7 @@ def _get_neo_position_triton( b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, b_image_token_tag: torch.Tensor, + b_max_image_q_idx: torch.Tensor, BLOCK_SIZE: tl.constexpr, ) -> torch.Tensor: cur_batch = tl.program_id(0) @@ -25,12 +26,30 @@ def _get_neo_position_triton( image_num = tl.load(b_image_nums + cur_batch) image_start_num = tl.load(b_image_start_num + cur_batch) start_loc = tl.load(b_start_loc + cur_batch) + + # Track per-batch last (batch-local) packed-q index where tag=True is + # written. -1 means this batch has no image tokens in its q range (either + # no images at all, or every image is entirely in prompt cache / out of + # the current q window). + max_image_q_idx = -1 + for i in range(image_num): local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) image_start_idx = start_loc + local_image_start_idx - cache_len image_len = tl.load(b_image_len + image_start_num + i) # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) + + # Batch-local index of this image's last token that actually lands in + # the current q window (after clipping to [0, q_seq_len)). Matches the + # mask applied in the stores below. 
+ cand_raw = local_image_start_idx - cache_len + image_len - 1 + candidate = tl.minimum(cand_raw, q_seq_len - 1) + contributes = (cand_raw >= 0) & (local_image_start_idx - cache_len < q_seq_len) + max_image_q_idx = tl.where( + contributes, tl.maximum(max_image_q_idx, candidate), max_image_q_idx + ) + for j in range(0, image_len, BLOCK_SIZE): off = j + tl.arange(0, BLOCK_SIZE) # 目前没考虑视频,所以t 恒为 0 @@ -66,6 +85,8 @@ def _get_neo_position_triton( & (local_image_start_idx - cache_len + off >= 0), ) + tl.store(b_max_image_q_idx + cur_batch, max_image_q_idx) + for i in range(image_num): local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) image_len = tl.load(b_image_len + image_start_num + i) @@ -96,10 +117,12 @@ def get_neo_position_triton( b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, b_image_token_tag: torch.Tensor, + b_max_image_q_idx: torch.Tensor, ) -> torch.Tensor: batch_size = b_q_seq_len.shape[0] assert batch_size == b_image_nums.shape[0] + assert b_max_image_q_idx.shape[0] == batch_size grid = (batch_size,) BLOCK_SIZE = 64 _get_neo_position_triton[grid]( @@ -115,6 +138,7 @@ def get_neo_position_triton( b_q_seq_len=b_q_seq_len, b_start_loc=b_start_loc, b_image_token_tag=b_image_token_tag, + b_max_image_q_idx=b_max_image_q_idx, BLOCK_SIZE=BLOCK_SIZE, ) @@ -136,6 +160,7 @@ def test(): b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda") + b_max_image_q_idx = torch.full((b_q_seq_len.shape[0],), -1, dtype=torch.int32, device="cuda") get_neo_position_triton( b_image_start_idx, b_image_thwd, @@ -147,10 +172,12 @@ def test(): b_q_seq_len, b_start_loc, b_image_token_tag, + b_max_image_q_idx, ) print(b_image_token_tag) print(position_ids) + print(b_max_image_q_idx) # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) # position_ids = ( From 783851265da952292a066fae8f4dc06f89a574fe Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 20 Apr 2026 15:03:32 +0000 Subject: [PATCH 29/41] import flash_attn_with_kvcache_neo --- lightllm/common/basemodel/attention/fa3/fp.py | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 3c99c4d095..e1ddcd8b85 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -4,6 +4,7 @@ from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from flash_attn_interface import flash_attn_with_kvcache as flash_attn_with_kvcache_neo from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor @@ -93,13 +94,31 @@ def _nomarl_prefill_att( Lq = q.shape[-1] sm_scale = 1.0 / (Lq ** 0.5) - # Optional image-token bidirectional attention (neo_chat*). When None, - # the kernel path reduces to plain causal fa3 (zero overhead). - extra_kwargs = {} + # neo_chat*: image-token bidirectional attention requires flash_attn_interface + # (sgl_kernel's flash_attn_with_kvcache does not support image_token_tag). 
if att_control.image_token_tag is not None: - extra_kwargs["image_token_tag"] = att_control.image_token_tag - if att_control.max_image_q_idx is not None: - extra_kwargs["max_image_q_idx"] = att_control.max_image_q_idx + extra_kwargs = {"image_token_tag": att_control.image_token_tag} + if att_control.max_image_q_idx is not None: + extra_kwargs["max_image_q_idx"] = att_control.max_image_q_idx + o = flash_attn_with_kvcache_neo( + q=q, + k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), + v_cache=v.view(v.shape[0], 1, v.shape[1], v.shape[2]), + page_table=self.page_table, + cache_seqlens=self.infer_state.b_seq_len, + cu_seqlens_q=self.cu_seqlens_q, + cu_seqlens_k_new=self.cu_seqlens_k, + max_seqlen_q=self.infer_state.max_q_seq_len, + softmax_scale=sm_scale, + causal=True, + window_size=window_size, + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + **extra_kwargs, + ) + return o o = flash_attn_with_kvcache( q=q, @@ -118,7 +137,6 @@ def _nomarl_prefill_att( v_descale=v_descale, return_softmax_lse=False, sinks=sink_weight, - **extra_kwargs, ) return o From b202de03de2d813feb488911a9276fb88c9cfc3e Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 20 Apr 2026 15:46:55 +0000 Subject: [PATCH 30/41] fix --- lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py index 223befad62..9756371af2 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py @@ -44,7 +44,7 @@ def _get_neo_position_triton( # the current q window (after clipping to [0, q_seq_len)). Matches the # mask applied in the stores below. 
cand_raw = local_image_start_idx - cache_len + image_len - 1 - candidate = tl.minimum(cand_raw, q_seq_len - 1) + candidate = tl.minimum(cand_raw, q_seq_len - 1).to(tl.int32) contributes = (cand_raw >= 0) & (local_image_start_idx - cache_len < q_seq_len) max_image_q_idx = tl.where( contributes, tl.maximum(max_image_q_idx, candidate), max_image_q_idx From b31991509298d88f35ba0c727496cb985fb02c3c Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Tue, 21 Apr 2026 06:31:26 +0000 Subject: [PATCH 31/41] reduce useless nblock --- .../triton_kernel/context_attention_fwd_neo.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py index a745b7c9da..c9795e113a 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py @@ -77,12 +77,15 @@ def _fwd_kernel( m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32) - block_end_loc = total_len - # absolute q positions in the request q_pos = prompt_cache_len + offs_m # [M] q_image_token_tag = tl.load(b_image_token_tag + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=False) + # per-M-block: only scan full K range if this block has image tokens + has_image = tl.reduce_or(q_image_token_tag.to(tl.int32), axis=0) > 0 + causal_end = tl.minimum(prompt_cache_len + block_start_loc + BLOCK_M, total_len) + block_end_loc = tl.where(has_image, total_len, causal_end) + for start_n in range(0, block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) From 468ea9839926812e89717479e40fbc13a12ec0ed Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Tue, 21 Apr 2026 07:26:34 +0000 Subject: [PATCH 32/41] delete max_image_q_idx and reduce_or in kernel --- .../common/basemodel/attention/base_att.py | 4 --- lightllm/common/basemodel/attention/fa3/fp.py | 2 -- .../layer_infer/transformer_layer_infer.py | 1 - lightllm/models/neo_chat_moe/infer_struct.py | 9 ------- .../layer_infer/transformer_layer_infer.py | 1 - .../triton_kernel/get_neo_position.py | 25 ----------------- .../att/prefill_att/test_fa3_neo.py | 27 +++---------------- 7 files changed, 4 insertions(+), 65 deletions(-) diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py index 1ee38a1d03..23a02bba13 100644 --- a/lightllm/common/basemodel/attention/base_att.py +++ b/lightllm/common/basemodel/attention/base_att.py @@ -71,10 +71,6 @@ class AttControl: nsa_decode: bool = False nsa_decode_dict: Dict = None image_token_tag: Optional[torch.Tensor] = None - # max_image_q_idx: int32[batch]. 每个 batch 内最后一个 image token 的 local q - # 下标;没有 image token 的 batch 填 -1。fa3 kernel 据此决定该 batch 的哪些 - # M-block 需要把 n_block_max 延长到 full seqlen_k,避免无 image 的 batch 翻倍计算。 - max_image_q_idx: Optional[torch.Tensor] = None @dataclass diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index e1ddcd8b85..8ce12ad335 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -98,8 +98,6 @@ def _nomarl_prefill_att( # (sgl_kernel's flash_attn_with_kvcache does not support image_token_tag). 
if att_control.image_token_tag is not None: extra_kwargs = {"image_token_tag": att_control.image_token_tag} - if att_control.max_image_q_idx is not None: - extra_kwargs["max_image_q_idx"] = att_control.max_image_q_idx o = flash_attn_with_kvcache_neo( q=q, k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]), diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index 0628411cbf..1f5f223654 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -112,7 +112,6 @@ def _context_attention_kernel( att_control = AttControl() att_control.image_token_tag = getattr(infer_state, "b_image_token_tag", None) - att_control.max_image_q_idx = getattr(infer_state, "b_max_image_q_idx", None) o_tensor = infer_state.prefill_att_state.prefill_att( q=_q, diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py index d7256af2b2..b8e6483def 100644 --- a/lightllm/models/neo_chat_moe/infer_struct.py +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -24,14 +24,6 @@ def init_some_extra_state(self, model: LlamaTpPartModel): self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda( non_blocking=True ) - # Pre-allocate to -1 so the "no images anywhere" fast path in - # get_neo_position (which skips the triton kernel entirely) still - # yields a valid per-batch tensor. When the kernel runs, it - # overwrites every batch slot unconditionally with its computed - # value (-1 if no image lands in that batch's q window). - self.b_max_image_q_idx = torch.full( - (bsz,), -1, dtype=torch.int32, device=self.b_q_seq_len.device - ) self.position_ids = self.get_neo_position(self.multimodal_params) else: b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])] @@ -106,6 +98,5 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: b_q_seq_len=self.b_q_seq_len, b_start_loc=self.b_q_start_loc, b_image_token_tag=self.b_image_token_tag, - b_max_image_q_idx=self.b_max_image_q_idx, ) return position_ids diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 4fe1d2e9d1..eab7b575bf 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -109,7 +109,6 @@ def _context_attention_kernel( att_control = AttControl() att_control.image_token_tag = getattr(infer_state, "b_image_token_tag", None) - att_control.max_image_q_idx = getattr(infer_state, "b_max_image_q_idx", None) o_tensor = infer_state.prefill_att_state.prefill_att( q=_q, diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py index 9756371af2..dc57870bf1 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py @@ -17,7 +17,6 @@ def _get_neo_position_triton( b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, b_image_token_tag: torch.Tensor, - b_max_image_q_idx: torch.Tensor, BLOCK_SIZE: tl.constexpr, ) -> torch.Tensor: cur_batch = tl.program_id(0) @@ -27,12 +26,6 @@ def _get_neo_position_triton( image_start_num = tl.load(b_image_start_num + cur_batch) start_loc = tl.load(b_start_loc + cur_batch) - # Track 
per-batch last (batch-local) packed-q index where tag=True is - # written. -1 means this batch has no image tokens in its q range (either - # no images at all, or every image is entirely in prompt cache / out of - # the current q window). - max_image_q_idx = -1 - for i in range(image_num): local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) image_start_idx = start_loc + local_image_start_idx - cache_len @@ -40,16 +33,6 @@ def _get_neo_position_triton( # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) - # Batch-local index of this image's last token that actually lands in - # the current q window (after clipping to [0, q_seq_len)). Matches the - # mask applied in the stores below. - cand_raw = local_image_start_idx - cache_len + image_len - 1 - candidate = tl.minimum(cand_raw, q_seq_len - 1).to(tl.int32) - contributes = (cand_raw >= 0) & (local_image_start_idx - cache_len < q_seq_len) - max_image_q_idx = tl.where( - contributes, tl.maximum(max_image_q_idx, candidate), max_image_q_idx - ) - for j in range(0, image_len, BLOCK_SIZE): off = j + tl.arange(0, BLOCK_SIZE) # 目前没考虑视频,所以t 恒为 0 @@ -85,8 +68,6 @@ def _get_neo_position_triton( & (local_image_start_idx - cache_len + off >= 0), ) - tl.store(b_max_image_q_idx + cur_batch, max_image_q_idx) - for i in range(image_num): local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) image_len = tl.load(b_image_len + image_start_num + i) @@ -117,12 +98,10 @@ def get_neo_position_triton( b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, b_image_token_tag: torch.Tensor, - b_max_image_q_idx: torch.Tensor, ) -> torch.Tensor: batch_size = b_q_seq_len.shape[0] assert batch_size == b_image_nums.shape[0] - assert b_max_image_q_idx.shape[0] == batch_size grid = (batch_size,) BLOCK_SIZE = 64 _get_neo_position_triton[grid]( @@ -138,7 +117,6 @@ def get_neo_position_triton( b_q_seq_len=b_q_seq_len, b_start_loc=b_start_loc, b_image_token_tag=b_image_token_tag, - b_max_image_q_idx=b_max_image_q_idx, BLOCK_SIZE=BLOCK_SIZE, ) @@ -160,7 +138,6 @@ def test(): b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda") - b_max_image_q_idx = torch.full((b_q_seq_len.shape[0],), -1, dtype=torch.int32, device="cuda") get_neo_position_triton( b_image_start_idx, b_image_thwd, @@ -172,12 +149,10 @@ def test(): b_q_seq_len, b_start_loc, b_image_token_tag, - b_max_image_q_idx, ) print(b_image_token_tag) print(position_ids) - print(b_max_image_q_idx) # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) # position_ids = ( diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py index 79b1e8e343..9a635f0445 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py @@ -182,20 +182,6 @@ def _build_inputs( b_seq_len = seq_lens.to(torch.int32) b_prompt_cache_len = prompt_cache_lens.to(torch.int32) - # Per-batch last image-token index in *batch-local* packed-q coordinates. - # Matches NeoChatInferStateInfo._compute_b_max_image_q_idx semantics: - # shape int32[batch]; value == -1 means that batch has no image tokens. 
- # Computed on CPU (b_image_token_tag is still on CPU here) so no D2H sync - # — keeps the eventual flash_attn_with_kvcache call CUDA-graph-safe. - b_max_image_q_idx_cpu = torch.full((batch,), -1, dtype=torch.int32) - for b in range(batch): - start = int(b_q_start_loc[b].item()) - length = int(q_seq_lens[b].item()) - seg = b_image_token_tag[start : start + length] - idx = torch.nonzero(seg, as_tuple=False) - if idx.numel() > 0: - b_max_image_q_idx_cpu[b] = int(idx[-1, 0].item()) - q = torch.randn((sum_q, Hq, D), dtype=dtype, device=device) k = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device) v = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device) @@ -212,7 +198,6 @@ def _build_inputs( max_q_seq_len_in_batch=max_q_seq_len_in_batch, req_to_token_indexs=req_to_token_indexs.to(device), b_image_token_tag=b_image_token_tag.to(device), - b_max_image_q_idx=b_max_image_q_idx_cpu.to(device), q_seq_lens=q_seq_lens, prompt_cache_lens=prompt_cache_lens, ) @@ -266,14 +251,9 @@ def _fa3_prefill_with_image_tag(inputs: dict) -> torch.Tensor: # image-token bidirectional attention. Packed like q (shape [sum_q], # bool). Rows where the tag is True are allowed to attend to every # real key in the request (not just the causal prefix). - # - # b_max_image_q_idx is int32[batch]: per-batch last image-token index - # in batch-local coordinates, -1 if no image in that batch. Lets the - # fa3 kernel skip n_block_max extension for text-only requests, which - # is critical for mixed-modality batches. Pre-computed in - # _build_inputs (host-side) so this call is CUDA-graph safe. + # The kernel uses warp OR reduce to detect image tokens per M-block + # and extends n_block_max for full attention automatically. image_token_tag=inputs["b_image_token_tag"], - max_image_q_idx=inputs["b_max_image_q_idx"], ) return o @@ -505,7 +485,8 @@ def triton_run(): dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.bfloat16, seed=0, max_q_seq_len=128, max_prompt_cache_len=256), dict(batch=8, Hq=16, Hk=4, D=128, dtype=torch.bfloat16, seed=1, max_q_seq_len=256, max_prompt_cache_len=512), dict(batch=16, Hq=28, Hk=4, D=128, dtype=torch.bfloat16, seed=2, max_q_seq_len=128, max_prompt_cache_len=256), - dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.float16, seed=3, max_q_seq_len=256, max_prompt_cache_len=512), + # FP16 case disabled: current build compiled without FP16 support + # dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.float16, seed=3, max_q_seq_len=256, max_prompt_cache_len=512), ] print("=" * 100) From 77dc66fa71122710b85afd2cb94c96dc26046c4e Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Tue, 21 Apr 2026 08:22:01 +0000 Subject: [PATCH 33/41] use triton when import fa3 failed --- lightllm/common/basemodel/attention/fa3/fp.py | 15 ++++++++++++++- .../layer_infer/transformer_layer_infer.py | 11 +++++++++++ .../layer_infer/transformer_layer_infer.py | 11 +++++++++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 8ce12ad335..6a0ecfc1af 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -4,7 +4,14 @@ from typing import Optional, TYPE_CHECKING from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.sgl_utils import flash_attn_with_kvcache -from flash_attn_interface import flash_attn_with_kvcache as flash_attn_with_kvcache_neo + +try: + from flash_attn_interface import flash_attn_with_kvcache as 
flash_attn_with_kvcache_neo + + HAS_FLASH_ATTN_INTERFACE = True +except ImportError: + flash_attn_with_kvcache_neo = None + HAS_FLASH_ATTN_INTERFACE = False from lightllm.utils.envs_utils import get_env_start_args from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor @@ -97,6 +104,12 @@ def _nomarl_prefill_att( # neo_chat*: image-token bidirectional attention requires flash_attn_interface # (sgl_kernel's flash_attn_with_kvcache does not support image_token_tag). if att_control.image_token_tag is not None: + if not HAS_FLASH_ATTN_INTERFACE: + raise ImportError( + "flash_attn_interface (fa3-neo) is required for image_token_tag bidirectional " + "attention. Install it or set LIGHTLLM_NEO_PREFILL_TRITON_BACKEND=1 to use the " + "triton fallback." + ) extra_kwargs = {"image_token_tag": att_control.image_token_tag} o = flash_attn_with_kvcache_neo( q=q, diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index 1f5f223654..a820ddaaf6 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -12,8 +12,19 @@ import torch.distributed as dist from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.common.basemodel.attention.base_att import AttControl +from lightllm.common.basemodel.attention.fa3.fp import HAS_FLASH_ATTN_INTERFACE +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) _USE_TRITON_PREFILL = os.environ.get("LIGHTLLM_NEO_PREFILL_TRITON_BACKEND", "0").strip().lower() in ("1", "true") +if not _USE_TRITON_PREFILL and not HAS_FLASH_ATTN_INTERFACE: + logger.warning( + "flash_attn_interface (fa3-neo) is not installed; falling back to triton prefill backend " + "for neo_chat. Install fa3-neo or set LIGHTLLM_NEO_PREFILL_TRITON_BACKEND=1 to silence " + "this warning." + ) + _USE_TRITON_PREFILL = True class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer): diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index eab7b575bf..66a320fe02 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -12,8 +12,19 @@ import torch.distributed as dist from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.common.basemodel.attention.base_att import AttControl +from lightllm.common.basemodel.attention.fa3.fp import HAS_FLASH_ATTN_INTERFACE +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) _USE_TRITON_PREFILL = os.environ.get("LIGHTLLM_NEO_PREFILL_TRITON_BACKEND", "0").strip().lower() in ("1", "true") +if not HAS_FLASH_ATTN_INTERFACE: + logger.warning( + "flash_attn_interface (fa3-neo) is not installed; falling back to triton prefill backend " + "for neo_chat_moe. Install fa3-neo or set LIGHTLLM_NEO_PREFILL_TRITON_BACKEND=1 to silence " + "this warning." 
+ ) + _USE_TRITON_PREFILL = True class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer): From 4959f4e86383480c3ce60c245ec41e697492b95e Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Tue, 21 Apr 2026 08:35:49 +0000 Subject: [PATCH 34/41] verify fa3 image_token_tag --- lightllm/common/basemodel/attention/fa3/fp.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py index 6a0ecfc1af..2e63a8dc2e 100644 --- a/lightllm/common/basemodel/attention/fa3/fp.py +++ b/lightllm/common/basemodel/attention/fa3/fp.py @@ -7,6 +7,12 @@ try: from flash_attn_interface import flash_attn_with_kvcache as flash_attn_with_kvcache_neo + import inspect + + # Verify this is the neo-patched FA3 build (with image_token_tag support), + _sig = inspect.signature(flash_attn_with_kvcache_neo) + if "image_token_tag" not in _sig.parameters: + raise ImportError("flash_attn_interface found but missing image_token_tag support (need neo build)") HAS_FLASH_ATTN_INTERFACE = True except ImportError: From 77f224a06d5dbf4a2ec0329841c906c96742db35 Mon Sep 17 00:00:00 2001 From: Charles2530 <2569337619@qq.com> Date: Wed, 22 Apr 2026 22:52:02 +0800 Subject: [PATCH 35/41] add t2i/it2i thinking --- lightllm/server/api_openai.py | 37 ++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 55112d7a4b..0f741f6c79 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -308,7 +308,6 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req finish_reason = finish_reason_dict[sub_req_id] text = "".join(final_output_dict[sub_req_id]) - full_text = text # Handle reasoning content reasoning_text = None @@ -333,14 +332,12 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req tool_calls = None tool_choice = request.tool_choice tools = request.tools - if tool_choice != "none" and any([i in full_text for i in TOOLS_TAG_LIST]): - if finish_reason == "stop": - finish_reason = "tool_calls" + if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]): try: # 为 tool_call_parser 提供默认值 tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3" parser = FunctionCallParser(tools, tool_parser) - full_normal_text, call_info_list = parser.parse_non_stream(full_text) + text, call_info_list = parser.parse_non_stream(text) tool_calls = [] history_tool_calls_cnt = _get_history_tool_calls_cnt(request) for call_info in call_info_list: @@ -358,8 +355,8 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req HTTPStatus.BAD_REQUEST, "Failed to parse fc related info to json format!", ) - if finish_reason == "tool_calls": - text = "" + if tool_calls and finish_reason == "stop": + finish_reason = "tool_calls" chat_message = ChatMessage( role="assistant", content=text if text else "", @@ -610,13 +607,17 @@ def _normalize_image_b64_for_multimodal(image: Union[str, bytes]) -> str: return image -def _apply_image_generation_stop(chat_request: ChatCompletionRequest, image_start_tag: str) -> None: +def _apply_image_generation_stop( + chat_request: ChatCompletionRequest, image_start_tag: str, image_only: bool = False +) -> None: stop = chat_request.stop or [] if isinstance(stop, str): stop = [stop] stop = list(stop) if image_start_tag not in stop: stop.append(image_start_tag) + if 
chat_request.chat_template_kwargs.get("enable_thinking", False) and image_only: + stop.append("") # TODO: from model config chat_request.stop = stop @@ -683,8 +684,9 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request image_start_tag = g_objs.httpserver_manager.tokenizer.image_start_tag image_tag = g_objs.httpserver_manager.tokenizer.image_tag + image_only = request.modalities == ["image"] - _apply_image_generation_stop(chat_request, image_start_tag) + _apply_image_generation_stop(chat_request, image_start_tag, image_only=image_only) created_time = int(time.time()) @@ -694,7 +696,10 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request x2i_params = X2IParams() x2i_params.init_from_image_config(request.image_config) - if request.modalities == ["image"]: + enable_thinking = request.chat_template_kwargs.get("enable_thinking", False) + print(f"x2i_params: {x2i_params} {image_only} {enable_thinking}", flush=True) + + if image_only and not enable_thinking: return await _chat_completion_image_only(request, raw_request, prompt, multimodal_params, x2i_params) if not request.stream: @@ -706,7 +711,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request completion_tokens = 0 finish_reason: Optional[str] = "stop" group_request_id = None - max_image_gen_num = 15 # TODO: make this configurable + max_image_gen_num = 15 if not image_only else 1 # TODO: make this configurable while max_image_gen_num > 0: max_image_gen_num -= 1 @@ -730,7 +735,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request full_text += output_chunk - if need_call_x2i: + if need_call_x2i or image_only: prompt += output_chunk images = await g_objs.httpserver_manager.generate_image( prompt, x2i_params, multimodal_params.clone(), request=raw_request, input_image_num=input_image_num @@ -741,7 +746,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request response_images.extend(_message_contents_from_raw_images(images, request.image_config.image_type)) for image in images: prompt += image_tag - full_text += image_tag + full_text += image_tag if not image_only else "" multimodal_params.add_image({"type": "base64", "data": _normalize_image_b64_for_multimodal(image)}) else: break @@ -797,7 +802,7 @@ async def stream_result() -> AsyncGenerator[bytes, None]: completion_tokens = 0 finish_reason = None group_request_id = None - max_image_gen_num = 15 # TODO: make this configurable + max_image_gen_num = 15 if not image_only else 1 # TODO: make this configurable while max_image_gen_num > 0: max_image_gen_num -= 1 @@ -857,11 +862,11 @@ async def stream_result() -> AsyncGenerator[bytes, None]: ) yield ("data: " + json.dumps(stream_resp.model_dump(), ensure_ascii=False) + "\n\n").encode("utf-8") - if need_call_x2i: + if need_call_x2i or image_only: prompt += output_chunk images = await g_objs.httpserver_manager.generate_image( - prompt, x2i_params, multimodal_params.clone(), request=raw_request + prompt, x2i_params, multimodal_params.clone(), request=raw_request, input_image_num=input_image_num ) if images is None or len(images) == 0: From 3e841651062f273f9c7c8b682b2d6c0e47d5383b Mon Sep 17 00:00:00 2001 From: Charles2530 <2569337619@qq.com> Date: Wed, 22 Apr 2026 22:52:27 +0800 Subject: [PATCH 36/41] fix build prompt for tool call --- lightllm/server/build_prompt.py | 34 ++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/lightllm/server/build_prompt.py 
b/lightllm/server/build_prompt.py index a33e35e815..21c1cde678 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -45,10 +45,30 @@ def init_tokenizer(args): return +def _normalize_tool_call_arguments(messages: list) -> None: + # Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility + # Qwen35's chat template expects arguments to be a dict (uses |items filter) + # but OpenAI format sends arguments as a JSON string + for msg in messages: + tool_calls = msg.get("tool_calls") + if tool_calls and isinstance(tool_calls, list): + for tool_call in tool_calls: + func = tool_call.get("function") + if func and isinstance(func, dict): + args = func.get("arguments") + if isinstance(args, str) and args: + try: + func["arguments"] = json.loads(args) + except (json.JSONDecodeError, TypeError): + pass + + async def build_prompt(request, tools) -> str: global tokenizer # pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别 messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages] + _normalize_tool_call_arguments(messages) + kwargs = {"conversation": messages} if request.character_settings: kwargs["character_settings"] = request.character_settings @@ -60,15 +80,7 @@ async def build_prompt(request, tools) -> str: try: input_str = tokenizer.apply_chat_template(**kwargs, tokenize=False, add_generation_prompt=True, tools=tools) - except: - # This except branch will be triggered when the chosen model - # has a different tools input format that is not compatiable - # with openAI's apply_chat_template tool_call format, like Mistral. - tools = [t if "function" in t else {"function": t} for t in tools] - input_str = tokenizer.apply_chat_template( - **kwargs, - tokenize=True, - add_generation_prompt=True, - tools=tools, - ) + except BaseException as e: + logger.error(f"Failed to build prompt: {e}") + raise e return input_str From b2f456510d8b0e3ace921b428b37cfe6c1d0098e Mon Sep 17 00:00:00 2001 From: Charles2530 <2569337619@qq.com> Date: Wed, 22 Apr 2026 23:00:42 +0800 Subject: [PATCH 37/41] dynamic_resolution, height/weight, seed, image_size --- lightllm/server/api_models.py | 93 +++++++++++-------- lightllm/server/core/objs/x2i_params.py | 11 ++- lightllm/server/function_call_parser.py | 5 +- lightllm/server/httpserver/manager.py | 3 +- .../server/x2i_server/lightx2v/adapter.py | 17 ++-- lightllm/server/x2i_server/manager.py | 2 +- 6 files changed, 78 insertions(+), 53 deletions(-) diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index c2fce7dd95..d6ceebfea8 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -121,7 +121,7 @@ class CompletionRequest(BaseModel): prompt: Union[str, List[str], List[int], List[List[int]]] suffix: Optional[str] = None max_tokens: Optional[int] = Field( - default=16384, deprecated="max_tokens is deprecated, please use max_completion_tokens instead" + default=256000, deprecated="max_tokens is deprecated, please use max_completion_tokens instead" ) max_completion_tokens: Optional[int] = None temperature: Optional[float] = 1.0 @@ -197,7 +197,7 @@ class ChatCompletionRequest(BaseModel): stream_options: Optional[StreamOptions] = None stop: Optional[Union[str, List[str]]] = None max_tokens: Optional[int] = Field( - default=16384, deprecated="max_tokens is deprecated, please use max_completion_tokens instead" + default=256000, deprecated="max_tokens is deprecated, please use max_completion_tokens instead" ) 
max_completion_tokens: Optional[int] = None presence_penalty: Optional[float] = 0.0 @@ -389,7 +389,7 @@ def ensure_id_is_str(cls, v): "21:9", ] -ImageSize: TypeAlias = Literal["0.5K", "1K", "2K", "4K"] +ImageSize: TypeAlias = Literal["0.5K", "1K", "1.5K", "2K", "4K"] Modality: TypeAlias = Literal["text", "image", "audio"] @@ -398,8 +398,10 @@ def ensure_id_is_str(cls, v): class ImageConfig(BaseModel): aspect_ratio: AspectRatio = "1:1" - image_size: ImageSize = "1K" + image_size: ImageSize = "1.5K" image_type: ImageType = "jpeg" + height: Optional[int] = -1 + width: Optional[int] = -1 # X2I / diffusion sampling (optional; server defaults apply when omitted) steps: Optional[int] = None guidance_scale: Optional[float] = None @@ -407,41 +409,50 @@ class ImageConfig(BaseModel): seed: Optional[int] = None num_images: Optional[int] = None cfg_norm: Optional[Literal["none", "cfg_zero_star", "global", "text_channel", "channel"]] = None - - # Mapping to actual resolutions (base resolution for 1K) + dynamic_resolution: Optional[bool] = True _aspect_ratio_to_resolution: ClassVar[dict] = { - "1:1": (1024, 1024), - "2:3": (832, 1248), - "3:2": (1248, 832), - "3:4": (864, 1184), - "4:3": (1184, 864), - "4:5": (896, 1152), - "5:4": (1152, 896), - "9:16": (768, 1344), - "16:9": (1920, 1080), - "21:9": (1536, 672), - } - - _size_multiplier: ClassVar[dict] = { - "0.5K": 0.5, - "1K": 1.0, - "2K": 2.0, - "4K": 4.0, + "1:1": {"1K": (1024, 1024), "1.5K": (1536, 1536), "2K": (2048, 2048)}, + "16:9": {"1.5K": (2048, 1152), "2K": (2720, 1536)}, + "9:16": {"1.5K": (1152, 2048), "2K": (1536, 2720)}, + "3:2": {"1.5K": (1888, 1248), "2K": (2496, 1664)}, + "2:3": {"1.5K": (1248, 1888), "2K": (1664, 2496)}, + "4:3": {"1.5K": (1760, 1312), "2K": (2368, 1760)}, + "3:4": {"1.5K": (1312, 1760), "2K": (1760, 2368)}, + "1:2": {"1.5K": (1088, 2144), "2K": (1440, 2880)}, + "2:1": {"1.5K": (2144, 1088), "2K": (2880, 1440)}, + "1:3": {"1.5K": (864, 2592), "2K": (1152, 3456)}, + "3:1": {"1.5K": (2592, 864), "2K": (3456, 1152)}, } + _size_set: ClassVar[set[str]] = {"1.5K", "2K"} - @field_validator("aspect_ratio") + @field_validator("image_size", mode="before") @classmethod - def validate_aspect_ratio(cls, v): - if v not in cls._aspect_ratio_to_resolution: - raise ValueError(f"Unsupported aspect ratio: {v}") + def normalize_image_size(cls, v): + if isinstance(v, str): + return v.strip().upper() return v - @field_validator("image_size") - @classmethod - def validate_image_size(cls, v): - if v not in cls._size_multiplier: - raise ValueError(f"Unsupported image size: {v}") - return v + @model_validator(mode="after") + def validate_resolution_config(self): + has_custom_height = self.height is not None and self.height > 0 + has_custom_width = self.width is not None and self.width > 0 + has_any_custom = (self.height is not None and self.height != -1) or ( + self.width is not None and self.width != -1 + ) + + # If custom resolution is provided, require both height/width and both must be positive. + if has_any_custom: + if not has_custom_height or not has_custom_width: + raise ValueError("height and width must both be provided as positive integers") + self.dynamic_resolution = False + return self + + # Otherwise, validate ratio and logical image size. 
+ if self.aspect_ratio not in self._aspect_ratio_to_resolution: + raise ValueError(f"Unsupported aspect ratio: {self.aspect_ratio}") + if self.image_size not in self._size_set: + raise ValueError(f"Unsupported image size: {self.image_size}") + return self @field_validator("image_type") @classmethod @@ -452,16 +463,16 @@ def validate_image_type(cls, v): def get_resolution(self): """Return scaled resolution (width, height)""" - base = self._aspect_ratio_to_resolution[self.aspect_ratio] - if base is None: - return None # extended ratios don't have fixed base - - scale = self._size_multiplier[self.image_size] - w, h = base - w, h = int(w * scale), int(h * scale) from lightllm.models.neo_chat_moe.vision_process import smart_resize - h, w = smart_resize(h, w, factor=32, min_pixels=512 * 512, max_pixels=2048 * 2048) + print(f"self.height: {self.height}, self.width: {self.width}", flush=True) + if self.height > -1 and self.width > -1: + w, h = self.width, self.height + else: + base = self._aspect_ratio_to_resolution[self.aspect_ratio][self.image_size] + w, h = base + + h, w = smart_resize(h, w, factor=32, min_pixels=1024 * 1024, max_pixels=2048 * 2048) return w, h diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py index 281e5f66d2..cac7b9abdd 100644 --- a/lightllm/server/core/objs/x2i_params.py +++ b/lightllm/server/core/objs/x2i_params.py @@ -54,9 +54,10 @@ class X2IParams(ctypes.Structure): _image_guidance_scale: float = 1.0 _seed: int = 42 _num_images: int = 1 - _cfg_norm: CfgNormType = CfgNormType.NONE - _cfg_interval: float = (-1, 2) + _cfg_norm: CfgNormType = CfgNormType.GLOBAL + _cfg_interval: float = (0, 1) _timestep_shift: float = 3.0 + _dynamic_resolution: bool = True def init(self, **kwargs): def _get(key, default): @@ -76,9 +77,11 @@ def _get(key, default): self.past_kvcache = PastKVCachePageList() self.past_kvcache_text = PastKVCachePageList() self.past_kvcache_img = PastKVCachePageList() + self.dynamic_resolution = _get("dynamic_resolution", X2IParams._dynamic_resolution) self.total_prompt_tokens = 0 self.request_id = 0 self.has_updated_hw = False + self.first_image = True def init_from_image_config(self, image_config: Any) -> None: """从 HTTP `image_config`(api_models.ImageConfig)填充,与 `init(**kwargs)` 共用默认值逻辑。""" @@ -98,6 +101,8 @@ def init_from_image_config(self, image_config: Any) -> None: kwargs["seed"] = image_config.seed if image_config.num_images is not None: kwargs["num_images"] = image_config.num_images + if image_config.dynamic_resolution is not None: + kwargs["dynamic_resolution"] = image_config.dynamic_resolution if image_config.cfg_norm is not None: for e in CfgNormType: if e.as_str() == image_config.cfg_norm: @@ -106,6 +111,8 @@ def init_from_image_config(self, image_config: Any) -> None: self.init(**kwargs) def update_hw(self, width: int, height: int): + if not self.dynamic_resolution: + return if self.has_updated_hw: return from lightllm.models.neo_chat_moe.vision_process import smart_resize diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 3a8fddf744..e9700289c0 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -11,7 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
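# Note on the ImageConfig.get_resolution change in api_models.py above (PATCH 37):
# an explicit height/width pair wins when both are positive; otherwise the
# (aspect_ratio, image_size) table supplies the base shape, and smart_resize
# (factor=32 with a min/max pixel budget) snaps the result to a 32-aligned size.
# Rough standalone sketch of that selection logic -- the table entries are copied
# from the diff, and the rounding is only a simplified stand-in for the real
# lightllm.models.neo_chat_moe.vision_process.smart_resize:
def pick_resolution(aspect_ratio="16:9", image_size="1.5K", height=-1, width=-1):
    table = {
        "1:1": {"1K": (1024, 1024), "1.5K": (1536, 1536), "2K": (2048, 2048)},
        "16:9": {"1.5K": (2048, 1152), "2K": (2720, 1536)},
    }
    if height > -1 and width > -1:        # custom resolution overrides the table
        w, h = width, height
    else:
        w, h = table[aspect_ratio][image_size]
    h, w = round(h / 32) * 32, round(w / 32) * 32   # simplified 32-alignment
    return w, h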
+import ast import json +import os import orjson import logging import re @@ -28,6 +30,7 @@ from .api_models import Tool logger = logging.getLogger(__name__) +ENABLE_TOOL_NAME_CHECK = os.getenv("LIGHTLLM_ENABLE_TOOL_NAME_CHECK", "False").upper() in ["ON", "TRUE", "1"] TOOLS_TAG_LIST = [ "<|plugin|>", @@ -155,7 +158,7 @@ def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]: results = [] for act in action: name = act.get("name") - if name and name in tool_indices: + if name and (not ENABLE_TOOL_NAME_CHECK or name in tool_indices): results.append( ToolCallItem( tool_index=-1, # Caller should update this based on the actual tools array called diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 99ed011cb1..05db8c335e 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -487,7 +487,7 @@ async def generation_wrapper(prompt, sample, multimodal, request): else: # call t2i prompt_condition, prompt_uncondition = self.tokenizer.get_query_for_t2i(prompt, input_image_num) - # logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}") + logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}") (con_gen, uncon_gen) = await asyncio.gather( *[ generation_wrapper(prompt_condition, sample_params, multimodal_params, request), @@ -514,6 +514,7 @@ async def generation_wrapper(prompt, sample, multimodal, request): assert req_status.response is not None self.req_id_to_x2i_reqs.pop(x2i_req_id, None) + generation_params.first_image = False return req_status.response.images diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py index b738b00d7d..95a795f617 100644 --- a/lightllm/server/x2i_server/lightx2v/adapter.py +++ b/lightllm/server/x2i_server/lightx2v/adapter.py @@ -55,9 +55,7 @@ def _init_pipeline(self): support_tasks=["t2i", "i2i"], ) self.pipe.create_generator(config_json=self.args.x2v_gen_model_config) - self.pipe.modify_config({ - "load_kv_cache_in_pipeline_for_debug": False, - "save_result_for_debug": False}) + self.pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": False}) async def run(self): while True: @@ -95,13 +93,16 @@ async def _process(self, param: X2IParams): timestep_shift=param.timestep_shift, ) past_kv_cache = self.past_kv_cache_client.get_kv_cache_for_x2i( - param.past_kvcache.get_all(), param.past_kvcache.token_len) + param.past_kvcache.get_all(), param.past_kvcache.token_len + ) past_kv_cache_text = self.past_kv_cache_client.get_kv_cache_for_x2i( - param.past_kvcache_text.get_all(), param.past_kvcache_text.token_len) + param.past_kvcache_text.get_all(), param.past_kvcache_text.token_len + ) past_kv_cache_img = None if not is_t2i: past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i( - param.past_kvcache_img.get_all(), param.past_kvcache_img.token_len) + param.past_kvcache_img.get_all(), param.past_kvcache_img.token_len + ) dist.barrier() # ensure all workers have got the kv cache before generation starts @@ -122,8 +123,10 @@ async def _process(self, param: X2IParams): past_kv_cache_text, past_kv_cache_img, ) + seed = param.seed if param.first_image else None + logger.info(f"seed: {seed} {param.seed} first_image: {param.first_image}") image = self.pipe.generate( - seed=param.seed + param.past_kvcache.img_len, + seed=seed, save_result_path="", target_shape=[param.height, param.width], ) diff --git 
a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py index aef9b0ab88..c48f283e79 100644 --- a/lightllm/server/x2i_server/manager.py +++ b/lightllm/server/x2i_server/manager.py @@ -108,7 +108,7 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams for i in range(param.num_images): self.gen_pipe.runner.set_kvcache(past_kv_cache, past_kv_cache_text) image = self.gen_pipe.generate( - seed=param.seed + param.past_kvcache.img_len + i, + seed=param.seed if param.first_image else None, save_result_path="", # 返回base64,不需要指定路径了 target_shape=[param.height, param.width], # Height, Width ) From c99e4b35e0436f6650431e671674765c4cea6d99 Mon Sep 17 00:00:00 2001 From: Weichao Luo Date: Thu, 23 Apr 2026 08:14:24 +0000 Subject: [PATCH 38/41] fix lint. --- .../triton_kernel/kv_cache_offload.py | 32 +- .../layer_infer/transformer_layer_infer.py | 1 - lightllm/models/neo_chat/model.py | 12 +- lightllm/models/neo_chat_moe/infer_struct.py | 1 - lightllm/models/neo_chat_moe/model.py | 12 +- lightllm/server/api_cli.py | 3 +- lightllm/server/api_lightllm.py | 2 +- lightllm/server/api_openai.py | 6 +- lightllm/server/api_start.py | 1 + lightllm/server/core/objs/start_args_type.py | 2 +- .../core/objs/token_chunck_hash_list.py | 6 +- .../model_infer/mode_backend/base_backend.py | 4 +- .../model_infer/mode_backend/past_kv_cache.py | 27 +- .../naive/configuration_neo_chat.py | 26 +- .../x2i_server/naive/configuration_neo_vit.py | 40 +- .../x2i_server/naive/modeling_fm_modules.py | 136 +++--- .../x2i_server/naive/modeling_neo_chat.py | 403 +++++++++++------- .../x2i_server/naive/modeling_neo_vit.py | 71 +-- .../test_context_attention_fwd_neo.py | 4 +- .../att/prefill_att/test_fa3_neo.py | 18 +- .../kv_trans_kernel/test_kv_trans_from_gpu.py | 65 ++- 21 files changed, 481 insertions(+), 391 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py index 01d32e908c..0cff35864c 100644 --- a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py +++ b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py @@ -706,7 +706,6 @@ def load_cpu_kv_to_gpu( return - @triton.jit def _offload_gpu_kv_to_cpu_for_x2i( token_indexes_ptr, @@ -800,8 +799,12 @@ def _offload_gpu_kv_to_cpu_for_x2i( + cpu_k_head_index * cpu_scale_stride3 + head_dim_range[None, :] ) - tl.store(cpu_scale_ptr, gpu_scale_data, mask=token_scale_mask, cache_modifier=".wt",) - + tl.store( + cpu_scale_ptr, + gpu_scale_data, + mask=token_scale_mask, + cache_modifier=".wt", + ) for v_head_index in range(gpu_v_head_num): gpu_v_head_index = v_head_index + gpu_v_start_head_index @@ -842,8 +845,12 @@ def _offload_gpu_kv_to_cpu_for_x2i( + cpu_v_head_index * cpu_scale_stride3 + head_dim_range[None, :] ) - tl.store(cpu_scale_ptr, gpu_scale_data, mask=token_scale_mask, cache_modifier=".wt",) - + tl.store( + cpu_scale_ptr, + gpu_scale_data, + mask=token_scale_mask, + cache_modifier=".wt", + ) @torch.no_grad() @@ -955,7 +962,7 @@ def offload_gpu_kv_to_cpu_for_x2i( assert token_block_size == triton.next_power_of_2(token_block_size) page_num = page_indexes.shape[0] - grid = (grid_num, ) + grid = (grid_num,) num_warps = 4 num_stages = 1 HAS_SCALE = gpu_kv_cache_scale is not None and cpu_kv_cache_scale is not None @@ -968,14 +975,13 @@ def offload_gpu_kv_to_cpu_for_x2i( gpu_scale_stride = [0 for _ in range(5)] cpu_scale_stride = [0 for _ in range(5)] - _offload_gpu_kv_to_cpu_for_x2i[grid]( - token_indexes_ptr = token_indexes, - 
gpu_kv_cache_ptr = gpu_kv_cache, - gpu_stride0 = gpu_kv_cache.stride(0), - gpu_stride1 = gpu_kv_cache.stride(1), - gpu_stride2 = gpu_kv_cache.stride(2), - gpu_kv_cache_scale_ptr = gpu_kv_cache_scale, + token_indexes_ptr=token_indexes, + gpu_kv_cache_ptr=gpu_kv_cache, + gpu_stride0=gpu_kv_cache.stride(0), + gpu_stride1=gpu_kv_cache.stride(1), + gpu_stride2=gpu_kv_cache.stride(2), + gpu_kv_cache_scale_ptr=gpu_kv_cache_scale, gpu_scale_stride0=gpu_scale_stride[0], gpu_scale_stride1=gpu_scale_stride[1], gpu_scale_stride2=gpu_scale_stride[2], diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index a820ddaaf6..c5c4f3c343 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -97,7 +97,6 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC def _context_attention_kernel( self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None ) -> torch.Tensor: - if _USE_TRITON_PREFILL: o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py index e17ad31fa1..80621034ae 100644 --- a/lightllm/models/neo_chat/model.py +++ b/lightllm/models/neo_chat/model.py @@ -46,12 +46,12 @@ def _init_inferstate_cls(self): pass def _init_att_backend(self): - self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class( - index=0, priority_list=["fa3"] - )(model=self) - self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class( - index=0, priority_list=["fa3"] - )(model=self) + self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0, priority_list=["fa3"])( + model=self + ) + self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(index=0, priority_list=["fa3"])( + model=self + ) def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py index b8e6483def..1693bcb964 100644 --- a/lightllm/models/neo_chat_moe/infer_struct.py +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -20,7 +20,6 @@ def __init__(self): def init_some_extra_state(self, model: LlamaTpPartModel): LlamaInferStateInfo.init_some_extra_state(self, model) if self.is_prefill: - bsz = self.b_q_seq_len.shape[0] self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda( non_blocking=True ) diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py index 7ae3cd27c0..fb0cb988d2 100644 --- a/lightllm/models/neo_chat_moe/model.py +++ b/lightllm/models/neo_chat_moe/model.py @@ -188,12 +188,12 @@ def _init_inferstate_cls(self): pass def _init_att_backend(self): - self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class( - index=0, priority_list=["fa3"] - )(model=self) - self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class( - index=0, priority_list=["fa3"] - )(model=self) + self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0, priority_list=["fa3"])( + model=self + ) + self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(index=0, priority_list=["fa3"])( + model=self + ) def _init_config(self): with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: diff --git 
a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index bc811ff17b..e2aa4f79cf 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -352,7 +352,8 @@ def make_argument_parser() -> argparse.ArgumentParser: type=str, choices=["colocate", "separate"], default="colocate", - help="Deployment mode for the x2i server. 'colocate' means the x2i server will run on the same gpus as the llm server, ", + help="Deployment mode for the x2i server. 'colocate' means the x2i server will " + "run on the same gpus as the llm server, ", ) parser.add_argument( "--x2i_use_naive_impl", diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index b20670e47c..7963423d8f 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -165,4 +165,4 @@ async def lightllm_generate_image(request: Request, httpserver_manager: HttpServ results = await httpserver_manager.generate_image(prompt, generation_params, multimodal_params, request=request) - return Response(content=json.dumps({"images": results}, ensure_ascii=False).encode("utf-8")) \ No newline at end of file + return Response(content=json.dumps({"images": results}, ensure_ascii=False).encode("utf-8")) diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 376cb6c397..d137bba162 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -185,9 +185,7 @@ def _get_images_and_audios(request: ChatCompletionRequest): # Local file path with file:// prefix file_path = img[7:] # Remove "file://" prefix with open(file_path, "rb") as f: - images.append( - {"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")} - ) + images.append({"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")}) else: raise ValueError( "Unrecognized image input. Supports local path, http url, base64, and PIL.Image." 
@@ -222,7 +220,7 @@ def _get_tools(request: ChatCompletionRequest): tools = [item.function.model_dump() for item in request.tools] return tools - + def _split_tool_argument_delta(arguments: Optional[str]) -> List[str]: """Split a complete JSON argument string into OpenAI-style deltas.""" if not arguments: diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 3cdc84c68c..5be30526a5 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -413,6 +413,7 @@ def normal_or_p_d_start(args): if args.enable_multimodal_x2i: from .x2i_server.manager import start_x2i_process, setup_devices + origin_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) setup_devices(args) process_manager.start_submodule_processes( diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 73b8b0060e..84253ab29e 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -187,7 +187,7 @@ class StartArgs: metric_port: int = field(default=None) multinode_httpmanager_port: int = field(default=12345) multi_level_kv_cache_port: int = field(default=None) - + # multi_modal_x2i enable_multimodal_x2i: bool = field(default=False) x2i_port: int = field(default=None) diff --git a/lightllm/server/core/objs/token_chunck_hash_list.py b/lightllm/server/core/objs/token_chunck_hash_list.py index de43cc4cc6..6927cc3488 100644 --- a/lightllm/server/core/objs/token_chunck_hash_list.py +++ b/lightllm/server/core/objs/token_chunck_hash_list.py @@ -91,10 +91,10 @@ def get_all(self): class PastKVCachePageList(CpuCachePageList): _pack_ = 4 - _fields_ = CpuCachePageList._fields_ +[ + _fields_ = CpuCachePageList._fields_ + [ ("token_len", ctypes.c_int), # 对应的token数量 ("img_tokens", ctypes.c_int), - ("img_len", ctypes.c_int) + ("img_len", ctypes.c_int), ] def __init__(self, token_len: int = 0): @@ -107,4 +107,4 @@ def get_compressed_len(self): return self.token_len - self.img_tokens + self.img_len def __repr__(self): - return f"(token_len={self.token_len}, img_tokens={self.img_tokens}, img_len={self.img_len})" \ No newline at end of file + return f"(token_len={self.token_len}, img_tokens={self.img_tokens}, img_len={self.img_len})" diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index df8dfe3316..7a4d851603 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -48,6 +48,7 @@ from .multi_level_kv_cache import MultiLevelKvCacheModule from .past_kv_cache import PastKVCacheModule + class ModeBackend: def __init__(self) -> None: self.shm_req_manager = ShmReqManager() @@ -658,7 +659,8 @@ def _get_classed_reqs( if self.args.enable_multimodal_x2i: true_finished_reqs = self.past_kv_cache_module.offload_finished_reqs_to_past_kv_cache( - finished_reqs=true_finished_reqs) + finished_reqs=true_finished_reqs + ) g_infer_context.filter_reqs(finished_reqs=true_finished_reqs) g_infer_context.pause_reqs(wait_pause_reqs, is_master_in_dp=self.is_master_in_dp) diff --git a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py index e6457ded9a..79f9e62a2f 100644 --- a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py @@ -22,9 +22,11 @@ class TransTask: sync_event: 
torch.cuda.Event buffer_index: int + class PastKVCacheModule(object): def __init__(self, backend): from .base_backend import ModeBackend + self.backend: ModeBackend = backend self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=False) self.page_index_buffer = torch.empty((1024 * LIGHTLLM_TOKEN_HASH_LIST_SIZE,), dtype=torch.int32, device="cuda") @@ -32,7 +34,6 @@ def __init__(self, backend): self.past_kv_cache_task: Deque[TransTask] = deque() self.sync_task_status_group = create_new_group_for_current_dp("gloo") - @lru_cache() def need_sync_compute_stream(self) -> bool: """ @@ -71,8 +72,9 @@ def offload_finished_reqs_to_past_kv_cache(self, finished_reqs: List[InferReq]) if req.past_kv_cache_task_status.is_running(): continue - assert req.past_kv_cache_task_status.is_not_started(), \ - f"req {req.req_id} has invalid past kv cache task status {req.past_kv_cache_task_status}" + assert ( + req.past_kv_cache_task_status.is_not_started() + ), f"req {req.req_id} has invalid past kv cache task status {req.past_kv_cache_task_status}" if self.need_sync_compute_stream(): g_infer_context.get_overlap_stream().synchronize() @@ -95,7 +97,9 @@ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]: start = free_index * LIGHTLLM_TOKEN_HASH_LIST_SIZE end = start + req.shm_req.past_kv_cache_page_indexes.size - page_indexes = torch.tensor(req.shm_req.past_kv_cache_page_indexes.get_all(), dtype=torch.int32, device='cpu', pin_memory=True) + page_indexes = torch.tensor( + req.shm_req.past_kv_cache_page_indexes.get_all(), dtype=torch.int32, device="cpu", pin_memory=True + ) num_tokens = req.shm_req.input_len assert req.cur_kv_len >= num_tokens @@ -104,13 +108,12 @@ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]: cuda_page_indexes = self.page_index_buffer[start:end] cuda_page_indexes.copy_(page_indexes) - token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0: num_tokens] + token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0:num_tokens] mem_manager = self.backend.model.mem_manager - if hasattr(mem_manager, "scale_buffer") and mem_manager.scale_buffer is not None: cpu_cache_meta = self.past_kv_cache_client.kv_cache_tensor_meta - cpu_kv_cache = self.past_kv_cache_client.cpu_kv_cache_tensor[:, :, :, :, 0:cpu_cache_meta.head_dim] + cpu_kv_cache = self.past_kv_cache_client.cpu_kv_cache_tensor[:, :, :, :, 0 : cpu_cache_meta.head_dim] cpu_kv_cache_scale = self.past_kv_cache_client.cpu_kv_cache_tensor[ :, :, :, :, cpu_cache_meta.head_dim : ].view(mem_manager.scale_buffer.dtype) @@ -137,11 +140,7 @@ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]: sync_event.wait(g_infer_context.get_overlap_stream()) # sync_event.synchronize() req.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING - return TransTask( - req_obj=req, - sync_event=sync_event, - buffer_index=free_index - ) + return TransTask(req_obj=req, sync_event=sync_event, buffer_index=free_index) def update_past_kv_cache_task_states(self): trans_ok_tasks = [] @@ -157,7 +156,7 @@ def update_past_kv_cache_task_states(self): dist.all_reduce(ok_tasks_num, op=dist.ReduceOp.MIN, group=self.sync_task_status_group) if ok_tasks_num.item() > 0: - finished, unfinished = trans_ok_tasks[:ok_tasks_num.item()], trans_ok_tasks[ok_tasks_num.item():] + finished, unfinished = trans_ok_tasks[: ok_tasks_num.item()], trans_ok_tasks[ok_tasks_num.item() :] self.past_kv_cache_task.extendleft(reversed(unfinished)) for task in 
finished: task.req_obj.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED @@ -171,4 +170,4 @@ def update_past_kv_cache_task_states(self): shm_req.candetoken_out_len = shm_req.shm_cur_output_len else: if len(trans_ok_tasks) > 0: - self.past_kv_cache_task.extendleft(reversed(trans_ok_tasks)) \ No newline at end of file + self.past_kv_cache_task.extendleft(reversed(trans_ok_tasks)) diff --git a/lightllm/server/x2i_server/naive/configuration_neo_chat.py b/lightllm/server/x2i_server/naive/configuration_neo_chat.py index 44171917a5..356284d7b0 100644 --- a/lightllm/server/x2i_server/naive/configuration_neo_chat.py +++ b/lightllm/server/x2i_server/naive/configuration_neo_chat.py @@ -18,7 +18,7 @@ def __init__(self, rope_theta_hw=10000.0, max_position_embeddings_hw=10000, **kw class NEOChatConfig(PretrainedConfig): - model_type = 'neo_chat' + model_type = "neo_chat" is_composition = True def __init__( @@ -34,13 +34,13 @@ def __init__( super().__init__(**kwargs) if vision_config is None: - vision_config = {'architectures': ['NEOVisionModel']} - logger.info('vision_config is None. Initializing the NEOVisionConfig with default values.') + vision_config = {"architectures": ["NEOVisionModel"]} + logger.info("vision_config is None. Initializing the NEOVisionConfig with default values.") if llm_config is None: - llm_config = {'architectures': ['Qwen3ForCausalLM']} - logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).') - assert 'architectures' in llm_config, "Should specify architecture in llm_config" + llm_config = {"architectures": ["Qwen3ForCausalLM"]} + logger.info("llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).") + assert "architectures" in llm_config, "Should specify architecture in llm_config" if isinstance(vision_config, dict): self.vision_config = NEOVisionConfig(**vision_config) @@ -66,12 +66,12 @@ def to_dict(self): `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, """ output = copy.deepcopy(self.__dict__) - output['vision_config'] = self.vision_config.to_dict() - output['llm_config'] = self.llm_config.to_dict() - output['model_type'] = self.__class__.model_type - output['use_backbone_lora'] = self.use_backbone_lora - output['use_llm_lora'] = self.use_llm_lora - output['downsample_ratio'] = self.downsample_ratio - output['template'] = self.template + output["vision_config"] = self.vision_config.to_dict() + output["llm_config"] = self.llm_config.to_dict() + output["model_type"] = self.__class__.model_type + output["use_backbone_lora"] = self.use_backbone_lora + output["use_llm_lora"] = self.use_llm_lora + output["downsample_ratio"] = self.downsample_ratio + output["template"] = self.template return output diff --git a/lightllm/server/x2i_server/naive/configuration_neo_vit.py b/lightllm/server/x2i_server/naive/configuration_neo_vit.py index 02837fea41..2e0d09356f 100644 --- a/lightllm/server/x2i_server/naive/configuration_neo_vit.py +++ b/lightllm/server/x2i_server/naive/configuration_neo_vit.py @@ -9,26 +9,26 @@ class NEOVisionConfig(PretrainedConfig): - model_type = 'neo_vision' + model_type = "neo_vision" def __init__( - self, - num_channels=3, - patch_size=16, - hidden_size=1024, - llm_hidden_size=2048, - downsample_ratio=0.5, - rope_theta_vision=10000.0, - max_position_embeddings_vision=10000, - min_pixels=65536, - max_pixels=4194304, - **kwargs, + self, + num_channels=3, + patch_size=16, + hidden_size=1024, + llm_hidden_size=2048, + 
downsample_ratio=0.5, + rope_theta_vision=10000.0, + max_position_embeddings_vision=10000, + min_pixels=65536, + max_pixels=4194304, + **kwargs, ): super().__init__(**kwargs) self.hidden_size = hidden_size - self.llm_hidden_size = llm_hidden_size, - self.downsample_ratio = downsample_ratio, + self.llm_hidden_size = (llm_hidden_size,) + self.downsample_ratio = (downsample_ratio,) self.rope_theta_vision = rope_theta_vision self.max_position_embeddings_vision = max_position_embeddings_vision self.num_channels = num_channels @@ -37,16 +37,16 @@ def __init__( self.max_pixels = max_pixels @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig': + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - if 'vision_config' in config_dict: - config_dict = config_dict['vision_config'] + if "vision_config" in config_dict: + config_dict = config_dict["vision_config"] - if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type: + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: logger.warning( f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.' + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." ) - return cls.from_dict(config_dict, **kwargs) \ No newline at end of file + return cls.from_dict(config_dict, **kwargs) diff --git a/lightllm/server/x2i_server/naive/modeling_fm_modules.py b/lightllm/server/x2i_server/naive/modeling_fm_modules.py index 56654b2580..ff36ddf37a 100644 --- a/lightllm/server/x2i_server/naive/modeling_fm_modules.py +++ b/lightllm/server/x2i_server/naive/modeling_fm_modules.py @@ -5,11 +5,14 @@ from functools import lru_cache from torch.utils.checkpoint import checkpoint + + def modulate(x, shift, scale=None): if shift is None: return x * (1 + scale) return x * (1 + scale) + shift + class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5): super().__init__() @@ -20,6 +23,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return output * self.weight + class TimestepEmbedder(nn.Module): """ Embeds scalar timesteps into vector representations. 
@@ -59,8 +63,8 @@ def forward(self, t): t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) return t_emb -class ResBlock(nn.Module): +class ResBlock(nn.Module): def __init__(self, channels, mlp_ratio=1.0): super().__init__() self.channels = channels @@ -81,6 +85,7 @@ def forward(self, x, y): h = self.mlp(h) return x + gate_mlp * h + # class FinalLayer(nn.Module): # def __init__(self, model_channels, out_channels): @@ -162,8 +167,8 @@ def forward(self, x, y): # return self.final_layer(x, y) -class FlowMatchingHead(nn.Module): +class FlowMatchingHead(nn.Module): def __init__(self, input_dim, out_dim, dim=1536, layers=12, mlp_ratio=1.0): super(FlowMatchingHead, self).__init__() self.net = SimpleMLPAdaLN(input_dim=input_dim, out_dim=out_dim, dim=dim, layers=layers, mlp_ratio=mlp_ratio) @@ -181,7 +186,7 @@ def forward(self, x, t): return x -def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0): +def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=16.0): # assert H * H == end # flat_patch_pos = torch.linspace(-1, 1, end) # N = end x_pos = torch.linspace(0, scale, width) @@ -189,22 +194,23 @@ def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 100 y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij") y_pos = y_pos.reshape(-1) x_pos = x_pos.reshape(-1) - freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 - x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 - y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 + freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4 + x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4 + y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4 x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) - freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 - freqs_cis = freqs_cis.reshape(height*width, -1) + freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2 + freqs_cis = freqs_cis.reshape(height * width, -1) return freqs_cis + class NerfEmbedder(nn.Module): def __init__(self, in_channels, hidden_size_input, max_freqs): super().__init__() self.max_freqs = max_freqs self.hidden_size_input = hidden_size_input self.embedder = nn.Sequential( - nn.Linear(in_channels+max_freqs**2, hidden_size_input, bias=True), + nn.Linear(in_channels + max_freqs ** 2, hidden_size_input, bias=True), ) @lru_cache @@ -213,7 +219,6 @@ def fetch_pos(self, patch_size, device, dtype): pos = pos[None, :, :].to(device=device, dtype=dtype) return pos - def forward(self, inputs): B, P2, C = inputs.shape patch_size = int(P2 ** 0.5) @@ -225,6 +230,7 @@ def forward(self, inputs): inputs = self.embedder(inputs) return inputs + class SimpleMLPAdaLN(nn.Module): """ The MLP for Diffusion Loss. 
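# Note: precompute_freqs_cis_2d above splits the rotary channels between the x
# and y patch coordinates. Both axes are rescaled to [0, scale] (scale=16 by
# default), the shared frequency table is f_k = theta ** (-4k / dim) for
# k = 0 .. dim/4 - 1, and each position contributes the complex phases
# exp(i * x * f_k) and exp(i * y * f_k), interleaved per frequency so the
# result has shape [height * width, dim / 2].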
@@ -243,7 +249,7 @@ def __init__( z_channels, num_res_blocks, patch_size, - grad_checkpointing=False + grad_checkpointing=False, ): super().__init__() @@ -254,15 +260,17 @@ def __init__( self.grad_checkpointing = grad_checkpointing self.patch_size = patch_size - self.cond_embed = nn.Linear(z_channels, patch_size**2*model_channels) + self.cond_embed = nn.Linear(z_channels, patch_size ** 2 * model_channels) self.input_proj = nn.Linear(in_channels, model_channels) res_blocks = [] for i in range(num_res_blocks): - res_blocks.append(ResBlock( - model_channels, - )) + res_blocks.append( + ResBlock( + model_channels, + ) + ) self.res_blocks = nn.ModuleList(res_blocks) self.final_layer = FinalLayer(model_channels, out_channels) @@ -275,6 +283,7 @@ def _basic_init(module): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) + self.apply(_basic_init) # Zero-out adaLN modulation layers @@ -297,7 +306,7 @@ def forward(self, x, c): x = self.input_proj(x) c = self.cond_embed(c) - y = c.reshape(-1, self.patch_size**2, self.model_channels) + y = c.reshape(-1, self.patch_size ** 2, self.model_channels) for block in self.res_blocks: x = block(x, y) @@ -309,6 +318,7 @@ class FinalLayer(nn.Module): """ The final layer adopted from DiT. """ + def __init__(self, model_channels, out_channels): super().__init__() self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) @@ -319,6 +329,7 @@ def forward(self, x): x = self.linear(x) return x + ################################################################################# # Sine/Cosine Positional Embedding Functions # ################################################################################# @@ -363,7 +374,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float64) omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) + omega = 1.0 / 10000 ** omega # (D/2,) pos = pos.reshape(-1) # (M,) out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product @@ -374,6 +385,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb + # -------------------------------------------------------- # Interpolate position embeddings for high-resolution # References: @@ -385,11 +397,11 @@ def interpolate_pos_embed(model_path, pe_key: str = "gen_pos_embed", new_len: in pos_embed_1d = state_dict[pe_key] _, ori_len, embed_dim = pos_embed_1d.shape - ori_size = int(ori_len**0.5) - new_size = int(new_len**0.5) + ori_size = int(ori_len ** 0.5) + new_size = int(new_len ** 0.5) if ori_size != new_size: - logger.info("Position interpolate from %dx%d to %dx%d" % (ori_size, ori_size, new_size, new_size)) + # logger.info("Position interpolate from %dx%d to %dx%d" % (ori_size, ori_size, new_size, new_size)) pos_embed_2d = pos_embed_1d.reshape(-1, ori_size, ori_size, embed_dim).permute(0, 3, 1, 2) pos_embed_2d = torch.nn.functional.interpolate( pos_embed_2d, size=(new_size, new_size), mode="bicubic", align_corners=False @@ -399,15 +411,13 @@ def interpolate_pos_embed(model_path, pe_key: str = "gen_pos_embed", new_len: in torch.save(state_dict, model_path) + class PositionEmbedding(nn.Module): def __init__(self, max_num_patch_per_side, hidden_size): super().__init__() self.max_num_patch_per_side = max_num_patch_per_side self.hidden_size = hidden_size - self.pos_embed = nn.Parameter( - torch.zeros(max_num_patch_per_side ** 2, hidden_size), - 
requires_grad=False - ) + self.pos_embed = nn.Parameter(torch.zeros(max_num_patch_per_side ** 2, hidden_size), requires_grad=False) self._init_weights() def _init_weights(self): @@ -457,37 +467,37 @@ def __init__(self, hidden_dim=4096, out_channels=3): # self.proj = nn.Linear(hidden_dim, 1024) # self.act = nn.SiLU() - self.up_blocks = nn.ModuleList([ - nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), - nn.Conv2d(hidden_dim, 512, kernel_size=3, padding=1), - nn.GroupNorm(32, 512), - nn.SiLU() - ), - nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), - nn.Conv2d(512, 256, kernel_size=3, padding=1), - nn.GroupNorm(32, 256), - nn.SiLU() - ), - nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), - nn.Conv2d(256, 64, kernel_size=3, padding=1), - nn.GroupNorm(32, 64), - nn.SiLU() - ), - nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), - nn.Conv2d(64, 32, kernel_size=3, padding=1), - nn.GroupNorm(16, 32), - nn.SiLU() - ), - nn.Sequential( - nn.Upsample(scale_factor=2, mode='nearest'), - nn.Conv2d(32, 16, kernel_size=3, padding=1), - nn.SiLU() - ) - ]) + self.up_blocks = nn.ModuleList( + [ + nn.Sequential( + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(hidden_dim, 512, kernel_size=3, padding=1), + nn.GroupNorm(32, 512), + nn.SiLU(), + ), + nn.Sequential( + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(512, 256, kernel_size=3, padding=1), + nn.GroupNorm(32, 256), + nn.SiLU(), + ), + nn.Sequential( + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(256, 64, kernel_size=3, padding=1), + nn.GroupNorm(32, 64), + nn.SiLU(), + ), + nn.Sequential( + nn.Upsample(scale_factor=2, mode="nearest"), + nn.Conv2d(64, 32, kernel_size=3, padding=1), + nn.GroupNorm(16, 32), + nn.SiLU(), + ), + nn.Sequential( + nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.SiLU() + ), + ] + ) self.out_conv = nn.Conv2d(16, out_channels, kernel_size=3, padding=1) @@ -520,8 +530,8 @@ def __init__(self): def forward(self, x): # x shape: [B, 4096, H/32, W/32] - x = self.ps1(self.act1(self.conv1(x))) # -> [B, 256, H/8, W/8] - x = self.ps2(self.conv2(x)) # -> [B, 3, H, W] + x = self.ps1(self.act1(self.conv1(x))) # -> [B, 256, H/8, W/8] + x = self.ps2(self.conv2(x)) # -> [B, 3, H, W] return x @@ -544,11 +554,12 @@ def __init__(self): def forward(self, x): # x shape: [B, 4096, H/32, W/32] - x = self.act1(self.conv1(self.ps1((x)))) # -> [B, 256, H/16, W/16] - x = self.act2(self.conv2(self.ps2((x)))) # -> [B, 256, H/8, W/8] - x = self.conv3(self.ps3((x))) # -> [B, 3, H, W] + x = self.act1(self.conv1(self.ps1((x)))) # -> [B, 256, H/16, W/16] + x = self.act2(self.conv2(self.ps2((x)))) # -> [B, 256, H/8, W/8] + x = self.conv3(self.ps3((x))) # -> [B, 3, H, W] return x + class PatchDecoder_preps1(nn.Module): def __init__(self): super().__init__() @@ -566,10 +577,11 @@ def __init__(self): def forward(self, x): # x shape: [B, 4096, H/32, W/32] - x = self.act1(self.conv1(self.ps1((x)))) # -> [B, 256, H/16, W/16] - x = self.ps3(self.conv2(self.ps2((x)))) # -> [B, 256, H/8, W/8] + x = self.act1(self.conv1(self.ps1((x)))) # -> [B, 256, H/16, W/16] + x = self.ps3(self.conv2(self.ps2((x)))) # -> [B, 256, H/8, W/8] return x + class ConvDecoder(nn.Module): def __init__(self, input_dim=4096, hidden_dim=1024): super().__init__() diff --git a/lightllm/server/x2i_server/naive/modeling_neo_chat.py b/lightllm/server/x2i_server/naive/modeling_neo_chat.py index 326e2a574d..fcd0cdf5f6 100644 --- 
a/lightllm/server/x2i_server/naive/modeling_neo_chat.py +++ b/lightllm/server/x2i_server/naive/modeling_neo_chat.py @@ -18,19 +18,29 @@ from .configuration_neo_chat import NEOChatConfig from .modeling_neo_vit import NEOVisionModel from .modeling_qwen3 import Qwen3ForCausalLM, create_block_causal_mask -from .modeling_fm_modules import PositionEmbedding, TimestepEmbedder, FlowMatchingHead, RMSNorm, NerfEmbedder, SimpleMLPAdaLN, ConvDecoder +from .modeling_fm_modules import ( + PositionEmbedding, + TimestepEmbedder, + FlowMatchingHead, + RMSNorm, + NerfEmbedder, + SimpleMLPAdaLN, + ConvDecoder, +) logger = logging.get_logger(__name__) -def version_cmp(v1, v2, op='eq'): +def version_cmp(v1, v2, op="eq"): import operator from packaging import version + op_func = getattr(operator, op) return op_func(version.parse(v1), version.parse(v2)) + def prepare_flash_kv_cache( past_key_values, current_len: int, @@ -82,6 +92,7 @@ def prepare_flash_kv_cache( layer.flash_k_cache = k_cache layer.flash_v_cache = v_cache + def clear_flash_kv_cache(past_key_values): if past_key_values is None: return @@ -132,9 +143,10 @@ def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): # Generate intra-image patch index (row-major order) patch_id_within_image = torch.arange(N_total, device=device) - patch_id_within_image = patch_id_within_image - torch.cumsum( - torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0 - )[patch_to_sample] + patch_id_within_image = ( + patch_id_within_image + - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample] + ) # Get H/W for each patch according to its image W_per_patch = W[patch_to_sample] @@ -146,8 +158,8 @@ def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): class NEOChatModel(PreTrainedModel): config_class = NEOChatConfig - main_input_name = 'pixel_values' - base_model_prefix = 'language_model' + main_input_name = "pixel_values" + base_model_prefix = "language_model" _supports_flash_attn_2 = True supports_gradient_checkpointing = True _no_split_modules = [ @@ -156,17 +168,17 @@ class NEOChatModel(PreTrainedModel): ] # support transformers 4.51.+ - _tp_plan = '' + _tp_plan = "" def __init__(self, config: NEOChatConfig, vision_model=None, language_model=None, use_flash_attn=True): super().__init__(config) - assert version_cmp(transformers.__version__, '4.37.0', 'ge') + assert version_cmp(transformers.__version__, "4.37.0", "ge") patch_size = config.vision_config.patch_size self.patch_size = patch_size self.template = config.template self.downsample_ratio = config.downsample_ratio - config.llm_config._attn_implementation = 'eager' + config.llm_config._attn_implementation = "eager" if vision_model is not None: self.vision_model = vision_model @@ -179,32 +191,33 @@ def __init__(self, config: NEOChatConfig, vision_model=None, language_model=None self.language_model = Qwen3ForCausalLM(config.llm_config) merge_size = int(1 / self.downsample_ratio) - output_dim = 3*(patch_size*merge_size)**2 + output_dim = 3 * (patch_size * merge_size) ** 2 llm_hidden_size = self.config.llm_config.hidden_size self.use_deep_fm_head = self.config.fm_head_layers > 2 self.use_pixel_head = self.config.use_pixel_head if self.use_deep_fm_head: - fm_head = FlowMatchingHead(llm_hidden_size, output_dim, dim=self.config.fm_head_dim, layers=self.config.fm_head_layers, mlp_ratio=self.config.fm_head_mlp_ratio) + fm_head = FlowMatchingHead( + llm_hidden_size, + output_dim, + dim=self.config.fm_head_dim, + layers=self.config.fm_head_layers, 
+ mlp_ratio=self.config.fm_head_mlp_ratio, + ) else: fm_head = nn.Sequential( - nn.Linear(llm_hidden_size, 4096, bias=True), - nn.GELU(), - nn.Linear(4096, output_dim, bias=True), - ) + nn.Linear(llm_hidden_size, 4096, bias=True), + nn.GELU(), + nn.Linear(4096, output_dim, bias=True), + ) timestep_embedder = TimestepEmbedder(llm_hidden_size) self.fm_modules = nn.ModuleDict( - { - "vision_model_mot_gen": vision_model_mot_gen, - "timestep_embedder": timestep_embedder, - "fm_head": fm_head - } - ) + {"vision_model_mot_gen": vision_model_mot_gen, "timestep_embedder": timestep_embedder, "fm_head": fm_head} + ) if self.use_pixel_head: self.fm_modules["fm_head"] = ConvDecoder(llm_hidden_size) - self.concat_time_token_num = config.concat_time_token_num self.time_token_id = 151682 self.noise_scale = config.noise_scale @@ -222,27 +235,22 @@ def __init__(self, config: NEOChatConfig, vision_model=None, language_model=None if self.add_noise_scale_embedding: noise_scale_embedder = TimestepEmbedder(llm_hidden_size) - self.fm_modules['noise_scale_embedder'] = noise_scale_embedder - - + self.fm_modules["noise_scale_embedder"] = noise_scale_embedder self.img_context_token_id = None self.img_start_token_id = 151670 # self.conv_template = get_conv_template(self.template) # self.system_message = self.conv_template.system_message - def extract_feature(self, pixel_values, gen_model=False, grid_hw=None): if gen_model: - return self.fm_modules['vision_model_mot_gen'](pixel_values=pixel_values, - output_hidden_states=False, - return_dict=True, - grid_hw=grid_hw).last_hidden_state + return self.fm_modules["vision_model_mot_gen"]( + pixel_values=pixel_values, output_hidden_states=False, return_dict=True, grid_hw=grid_hw + ).last_hidden_state else: - return self.vision_model(pixel_values=pixel_values, - output_hidden_states=False, - return_dict=True, - grid_hw=grid_hw).last_hidden_state + return self.vision_model( + pixel_values=pixel_values, output_hidden_states=False, return_dict=True, grid_hw=grid_hw + ).last_hidden_state def patchify(self, images, patch_size, channel_first=False): """ @@ -253,11 +261,11 @@ def patchify(self, images, patch_size, channel_first=False): x = images.reshape(shape=(images.shape[0], 3, h, patch_size, w, patch_size)) if channel_first: - x = torch.einsum('nchpwq->nhwcpq', x) + x = torch.einsum("nchpwq->nhwcpq", x) else: - x = torch.einsum('nchpwq->nhwpqc', x) + x = torch.einsum("nchpwq->nhwpqc", x) - x = x.reshape(shape=(images.shape[0], h * w, patch_size**2 * 3)) + x = x.reshape(shape=(images.shape[0], h * w, patch_size ** 2 * 3)) return x def unpatchify(sle, x, patch_size, h=None, w=None): @@ -266,12 +274,12 @@ def unpatchify(sle, x, patch_size, h=None, w=None): images: (N, 3, H, W) """ if h is None or w is None: - h = w = int(x.shape[1]**.5) + h = w = int(x.shape[1] ** 0.5) else: h = h // patch_size w = w // patch_size x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, 3)) - x = torch.einsum('nhwpqc->nchpwq', x) + x = torch.einsum("nhwpqc->nchpwq", x) images = x.reshape(shape=(x.shape[0], 3, h * patch_size, w * patch_size)) return images @@ -316,15 +324,25 @@ def _build_t2i_image_indexes(self, token_h, token_w, text_len, device): w_image = idx % token_w return torch.stack([t_image, h_image, w_image], dim=0) - - - def _t2i_predict_v(self, input_embeds, indexes_image, attn_mask, past_key_values, t, z, - image_token_num, timestep_embeddings=None, image_size=None): + def _t2i_predict_v( + self, + input_embeds, + indexes_image, + attn_mask, + past_key_values, + t, + z, + 
image_token_num, + timestep_embeddings=None, + image_size=None, + ): B, L = z.shape[0], z.shape[1] outputs = self.language_model.model( inputs_embeds=input_embeds, - image_gen_indicators=torch.ones((input_embeds.shape[0], input_embeds.shape[1]), dtype=torch.bool, device=input_embeds.device), + image_gen_indicators=torch.ones( + (input_embeds.shape[0], input_embeds.shape[1]), dtype=torch.bool, device=input_embeds.device + ), indexes=indexes_image, attention_mask=attn_mask, past_key_values=past_key_values, @@ -341,44 +359,47 @@ def _t2i_predict_v(self, input_embeds, indexes_image, attn_mask, past_key_values img_2d = torch.einsum("b h w c -> b c h w", img_reshaped) img_2d = img_2d.contiguous().view(B, -1, token_h, token_w) - smoothed_img_2d = self.fm_modules['fm_head'](img_2d) + smoothed_img_2d = self.fm_modules["fm_head"](img_2d) - smoothed_reshaped = smoothed_img_2d.view(B, 3, token_h, self.patch_size * merge_size, token_w, self.patch_size * merge_size) + smoothed_reshaped = smoothed_img_2d.view( + B, 3, token_h, self.patch_size * merge_size, token_w, self.patch_size * merge_size + ) smoothed_reshaped = torch.einsum("b c h p w q -> b h w p q c", smoothed_reshaped) - out_1d = smoothed_reshaped.contiguous().view(B, L, self.patch_size * merge_size * self.patch_size * merge_size * 3) + out_1d = smoothed_reshaped.contiguous().view( + B, L, self.patch_size * merge_size * self.patch_size * merge_size * 3 + ) x_pred = out_1d else: if self.use_deep_fm_head: x_pred = self.fm_modules["fm_head"]( - outputs.last_hidden_state[:, -image_token_num:].view(B*L, -1), t.repeat(B*L) + outputs.last_hidden_state[:, -image_token_num:].view(B * L, -1), t.repeat(B * L) ).view(B, L, -1) else: x_pred = self.fm_modules["fm_head"]( outputs.last_hidden_state[:, -image_token_num:].view(B, L, -1) ).view(B, L, -1) - v_pred = (x_pred - z) / (1 - t).clamp_min(self.config.t_eps) return v_pred - @torch.no_grad() - def it2i_generate(self, - past_key_values_condition, - past_key_values_text_uncondition, - past_key_values_img_uncondition, - text_lens, - cfg_scale=1, - img_cfg_scale=1, - cfg_norm='none', - enable_timestep_shift=True, - timestep_shift=3, - image_size=(256, 256), - num_steps=30, - cfg_interval=(0.1, 1.0), - batch_size=1, - t_eps=0.02, - ): + def it2i_generate( + self, + past_key_values_condition, + past_key_values_text_uncondition, + past_key_values_img_uncondition, + text_lens, + cfg_scale=1, + img_cfg_scale=1, + cfg_norm="none", + enable_timestep_shift=True, + timestep_shift=3, + image_size=(256, 256), + num_steps=30, + cfg_interval=(0.1, 1.0), + batch_size=1, + t_eps=0.02, + ): self.config.t_eps = t_eps device, dtype = self.get_cache_device_dtype(past_key_values_condition) @@ -394,12 +415,24 @@ def it2i_generate(self, indexes_image_img_uncondition = self._build_t2i_image_indexes(token_h, token_w, S3, device=device) for layer_idx in range(len(past_key_values_condition.layers)): - past_key_values_condition.layers[layer_idx].keys = past_key_values_condition.layers[layer_idx].keys.expand(batch_size, *past_key_values_condition.layers[layer_idx].keys.shape[1:]) - past_key_values_condition.layers[layer_idx].values = past_key_values_condition.layers[layer_idx].values.expand(batch_size, *past_key_values_condition.layers[layer_idx].values.shape[1:]) - past_key_values_text_uncondition.layers[layer_idx].keys = past_key_values_text_uncondition.layers[layer_idx].keys.expand(batch_size, *past_key_values_text_uncondition.layers[layer_idx].keys.shape[1:]) - past_key_values_text_uncondition.layers[layer_idx].values = 
past_key_values_text_uncondition.layers[layer_idx].values.expand(batch_size, *past_key_values_text_uncondition.layers[layer_idx].values.shape[1:]) - past_key_values_img_uncondition.layers[layer_idx].keys = past_key_values_img_uncondition.layers[layer_idx].keys.expand(batch_size, *past_key_values_img_uncondition.layers[layer_idx].keys.shape[1:]) - past_key_values_img_uncondition.layers[layer_idx].values = past_key_values_img_uncondition.layers[layer_idx].values.expand(batch_size, *past_key_values_img_uncondition.layers[layer_idx].values.shape[1:]) + past_key_values_condition.layers[layer_idx].keys = past_key_values_condition.layers[layer_idx].keys.expand( + batch_size, *past_key_values_condition.layers[layer_idx].keys.shape[1:] + ) + past_key_values_condition.layers[layer_idx].values = past_key_values_condition.layers[ + layer_idx + ].values.expand(batch_size, *past_key_values_condition.layers[layer_idx].values.shape[1:]) + past_key_values_text_uncondition.layers[layer_idx].keys = past_key_values_text_uncondition.layers[ + layer_idx + ].keys.expand(batch_size, *past_key_values_text_uncondition.layers[layer_idx].keys.shape[1:]) + past_key_values_text_uncondition.layers[layer_idx].values = past_key_values_text_uncondition.layers[ + layer_idx + ].values.expand(batch_size, *past_key_values_text_uncondition.layers[layer_idx].values.shape[1:]) + past_key_values_img_uncondition.layers[layer_idx].keys = past_key_values_img_uncondition.layers[ + layer_idx + ].keys.expand(batch_size, *past_key_values_img_uncondition.layers[layer_idx].keys.shape[1:]) + past_key_values_img_uncondition.layers[layer_idx].values = past_key_values_img_uncondition.layers[ + layer_idx + ].values.expand(batch_size, *past_key_values_img_uncondition.layers[layer_idx].values.shape[1:]) prepare_flash_kv_cache( past_key_values_condition, @@ -417,31 +450,32 @@ def it2i_generate(self, batch_size=batch_size, ) - # init noise image tokens grid_h = image_size[1] // self.patch_size grid_w = image_size[0] // self.patch_size grid_hw = torch.tensor([[grid_h, grid_w]] * batch_size, device=device) noise_scale = self.noise_scale - if self.noise_scale_mode in ("resolution", "dynamic", 'dynamic_sqrt'): - noise_scale = math.sqrt((grid_h*grid_w)/(merge_size**2) / self.noise_scale_base_image_seq_len) + if self.noise_scale_mode in ("resolution", "dynamic", "dynamic_sqrt"): + noise_scale = math.sqrt((grid_h * grid_w) / (merge_size ** 2) / self.noise_scale_base_image_seq_len) base = float(self.noise_scale_base_image_seq_len) - scale = math.sqrt((grid_h*grid_w)/(merge_size**2)/base) + scale = math.sqrt((grid_h * grid_w) / (merge_size ** 2) / base) noise_scale = scale * float(self.noise_scale) - if self.noise_scale_mode == 'dynamic_sqrt': + if self.noise_scale_mode == "dynamic_sqrt": noise_scale = math.sqrt(noise_scale) noise_scale = min(noise_scale, self.noise_scale_max_value) - image_prediction = noise_scale * torch.randn((batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype) + image_prediction = noise_scale * torch.randn( + (batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype + ) attention_mask_condition = {"full_attention": None} attention_mask_text_uncondition = {"full_attention": None} attention_mask_img_uncondition = {"full_attention": None} - timesteps = torch.linspace(0.0, 1.0, num_steps+1, device=device) + timesteps = torch.linspace(0.0, 1.0, num_steps + 1, device=device) if enable_timestep_shift: - timesteps = self._apply_time_schedule(timesteps, token_h*token_w, timestep_shift) + timesteps = 
self._apply_time_schedule(timesteps, token_h * token_w, timestep_shift) for step_i in range(num_steps): t = timesteps[step_i] @@ -450,16 +484,32 @@ def it2i_generate(self, z = self.patchify(image_prediction, self.patch_size * merge_size) image_input = self.patchify(image_prediction, self.patch_size, channel_first=True) - image_embeds = self.extract_feature(image_input.view(batch_size * grid_h*grid_w, -1), gen_model=True, grid_hw=grid_hw).view(batch_size, token_h*token_w, -1) - t_expanded = t.expand(batch_size*token_h*token_w) - timestep_embeddings = self.fm_modules['timestep_embedder'](t_expanded).view(batch_size, token_h*token_w, -1) + image_embeds = self.extract_feature( + image_input.view(batch_size * grid_h * grid_w, -1), gen_model=True, grid_hw=grid_hw + ).view(batch_size, token_h * token_w, -1) + t_expanded = t.expand(batch_size * token_h * token_w) + timestep_embeddings = self.fm_modules["timestep_embedder"](t_expanded).view( + batch_size, token_h * token_w, -1 + ) if self.add_noise_scale_embedding: - noise_scale_tensor = torch.full_like(t_expanded, noise_scale/self.noise_scale_max_value) - noise_embeddings = self.fm_modules['noise_scale_embedder'](noise_scale_tensor).view(batch_size, token_h*token_w, -1) + noise_scale_tensor = torch.full_like(t_expanded, noise_scale / self.noise_scale_max_value) + noise_embeddings = self.fm_modules["noise_scale_embedder"](noise_scale_tensor).view( + batch_size, token_h * token_w, -1 + ) timestep_embeddings += noise_embeddings image_embeds = image_embeds + timestep_embeddings - v_pred_condition = self._t2i_predict_v(image_embeds, indexes_image_condition, attention_mask_condition, past_key_values_condition, t, z, image_token_num=token_h*token_w, timestep_embeddings=timestep_embeddings,image_size=image_size) + v_pred_condition = self._t2i_predict_v( + image_embeds, + indexes_image_condition, + attention_mask_condition, + past_key_values_condition, + t, + z, + image_token_num=token_h * token_w, + timestep_embeddings=timestep_embeddings, + image_size=image_size, + ) if not use_cfg: v_pred = v_pred_condition elif cfg_scale == 1 and img_cfg_scale == 1: @@ -489,7 +539,7 @@ def it2i_generate(self, timestep_embeddings=timestep_embeddings, image_size=image_size, ) - v_pred = out_uncond + cfg_scale *(v_pred_condition - out_uncond) + v_pred = out_uncond + cfg_scale * (v_pred_condition - out_uncond) else: out_img_cond = self._t2i_predict_v( image_embeds, @@ -520,18 +570,17 @@ def it2i_generate(self, ) if cfg_scale > 1 or img_cfg_scale > 1: - if cfg_norm == 'global': + if cfg_norm == "global": norm_v_condition = torch.norm(v_pred_condition, dim=(1, 2), keepdim=True) norm_v_cfg = torch.norm(v_pred, dim=(1, 2), keepdim=True) scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) v_pred = v_pred * scale - elif cfg_norm == 'channel': + elif cfg_norm == "channel": norm_v_condition = torch.norm(v_pred_condition, dim=-1, keepdim=True) norm_v_cfg = torch.norm(v_pred, dim=-1, keepdim=True) scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) v_pred = v_pred * scale - z = z + (t_next - t) * v_pred image_prediction = self.unpatchify(z, self.patch_size * merge_size, image_size[1], image_size[0]) @@ -542,7 +591,6 @@ def it2i_generate(self, return image_prediction - def get_cache_device_dtype(self, cache): """ Returns (device, dtype) of a DynamicCache. 
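The sampling loop above is a plain rectified-flow integrator: _t2i_predict_v converts the head's prediction x_pred (treated as the t = 1 endpoint of the flow) into a velocity v = (x_pred - z) / (1 - t), and the loop then takes the Euler step z <- z + (t_next - t) * v. A minimal standalone sketch with toy shapes (all sizes and step values are illustrative):

    import torch

    t, t_next, eps = torch.tensor(0.4), torch.tensor(0.6), 0.02   # one step of the schedule (illustrative)
    z = torch.randn(1, 64, 48)        # current noisy patch tokens, toy shape [B, L, C]
    x_pred = torch.randn(1, 64, 48)   # predicted clean patches from the flow-matching head

    v_pred = (x_pred - z) / (1 - t).clamp_min(eps)   # same form as in _t2i_predict_v
    z_next = z + (t_next - t) * v_pred               # Euler update used by the loop
    # A single full step to t = 1, z + (1 - t) * v_pred, would land exactly on x_pred.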
@@ -553,20 +601,22 @@ def get_cache_device_dtype(self, cache): raise ValueError("Cache is empty") @torch.no_grad() - def t2i_generate(self, - past_key_values_condition, - past_key_values_uncondition, - text_lens, - cfg_scale=1, - timestep_shift=3, - enable_timestep_shift=True, - cfg_norm='none', - image_size=(256, 256), - num_steps=30, - cfg_interval=(0.1, 1.0), - batch_size=1, - t_eps=0.02): - assert cfg_norm in ['cfg_zero_star', 'global', 'none'], f"cfg_norm={cfg_norm}" + def t2i_generate( + self, + past_key_values_condition, + past_key_values_uncondition, + text_lens, + cfg_scale=1, + timestep_shift=3, + enable_timestep_shift=True, + cfg_norm="none", + image_size=(256, 256), + num_steps=30, + cfg_interval=(0.1, 1.0), + batch_size=1, + t_eps=0.02, + ): + assert cfg_norm in ["cfg_zero_star", "global", "none"], f"cfg_norm={cfg_norm}" merge_size = int(1 / self.downsample_ratio) self.config.t_eps = t_eps @@ -580,10 +630,18 @@ def t2i_generate(self, indexes_image_uncondition = self._build_t2i_image_indexes(token_h, token_w, S2, device=device) for layer_idx in range(len(past_key_values_condition.layers)): - past_key_values_condition.layers[layer_idx].keys = past_key_values_condition.layers[layer_idx].keys.expand(batch_size, *past_key_values_condition.layers[layer_idx].keys.shape[1:]) - past_key_values_condition.layers[layer_idx].values = past_key_values_condition.layers[layer_idx].values.expand(batch_size, *past_key_values_condition.layers[layer_idx].values.shape[1:]) - past_key_values_uncondition.layers[layer_idx].keys = past_key_values_uncondition.layers[layer_idx].keys.expand(batch_size, *past_key_values_uncondition.layers[layer_idx].keys.shape[1:]) - past_key_values_uncondition.layers[layer_idx].values = past_key_values_uncondition.layers[layer_idx].values.expand(batch_size, *past_key_values_uncondition.layers[layer_idx].values.shape[1:]) + past_key_values_condition.layers[layer_idx].keys = past_key_values_condition.layers[layer_idx].keys.expand( + batch_size, *past_key_values_condition.layers[layer_idx].keys.shape[1:] + ) + past_key_values_condition.layers[layer_idx].values = past_key_values_condition.layers[ + layer_idx + ].values.expand(batch_size, *past_key_values_condition.layers[layer_idx].values.shape[1:]) + past_key_values_uncondition.layers[layer_idx].keys = past_key_values_uncondition.layers[ + layer_idx + ].keys.expand(batch_size, *past_key_values_uncondition.layers[layer_idx].keys.shape[1:]) + past_key_values_uncondition.layers[layer_idx].values = past_key_values_uncondition.layers[ + layer_idx + ].values.expand(batch_size, *past_key_values_uncondition.layers[layer_idx].values.shape[1:]) # prepare flash cache once prepare_flash_kv_cache( @@ -600,27 +658,29 @@ def t2i_generate(self, # init noise image tokens grid_h = image_size[1] // self.patch_size grid_w = image_size[0] // self.patch_size - grid_hw = torch.tensor([[grid_h, grid_w]]*batch_size, device=device) + grid_hw = torch.tensor([[grid_h, grid_w]] * batch_size, device=device) noise_scale = self.noise_scale - if self.noise_scale_mode in ("resolution", "dynamic", 'dynamic_sqrt'): - noise_scale = math.sqrt((grid_h*grid_w)/(merge_size**2) / self.noise_scale_base_image_seq_len) + if self.noise_scale_mode in ("resolution", "dynamic", "dynamic_sqrt"): + noise_scale = math.sqrt((grid_h * grid_w) / (merge_size ** 2) / self.noise_scale_base_image_seq_len) base = float(self.noise_scale_base_image_seq_len) - scale = math.sqrt((grid_h*grid_w)/(merge_size**2)/base) + scale = math.sqrt((grid_h * grid_w) / (merge_size ** 2) / base) 
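To make the resolution-dependent noise scale concrete: in the "dynamic" modes the initial noise is scaled by sqrt(image tokens / noise_scale_base_image_seq_len), square-rooted once more for "dynamic_sqrt", and finally capped at noise_scale_max_value. A worked sketch with illustrative numbers (the patch size, merge size, base length and cap below are assumptions, not values read from any config):

    import math

    patch_size, merge_size = 16, 2                        # illustrative
    base_image_seq_len, noise_scale_max_value = 1024, 2.0
    base_noise_scale = 1.0
    width, height = 2048, 1024                            # requested image_size

    grid_h, grid_w = height // patch_size, width // patch_size            # 64 x 128 vision patches
    image_tokens = (grid_h * grid_w) / (merge_size ** 2)                  # 2048 LLM image tokens
    scale = math.sqrt(image_tokens / base_image_seq_len)                  # sqrt(2) ~= 1.414
    dynamic = min(scale * base_noise_scale, noise_scale_max_value)        # "dynamic": ~1.414
    dynamic_sqrt = min(math.sqrt(scale * base_noise_scale), noise_scale_max_value)  # "dynamic_sqrt": ~1.189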
noise_scale = scale * float(self.noise_scale) - if self.noise_scale_mode == 'dynamic_sqrt': + if self.noise_scale_mode == "dynamic_sqrt": noise_scale = math.sqrt(noise_scale) noise_scale = min(noise_scale, self.noise_scale_max_value) - image_prediction = noise_scale * torch.randn((batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype) + image_prediction = noise_scale * torch.randn( + (batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype + ) attention_mask_condition = {"full_attention": None} attention_mask_uncondition = {"full_attention": None} - timesteps = torch.linspace(0.0, 1.0, num_steps+1, device=device) + timesteps = torch.linspace(0.0, 1.0, num_steps + 1, device=device) if enable_timestep_shift: - timesteps = self._apply_time_schedule(timesteps, token_h*token_w, timestep_shift) + timesteps = self._apply_time_schedule(timesteps, token_h * token_w, timestep_shift) for step_i in range(num_steps): t = timesteps[step_i] @@ -628,43 +688,67 @@ def t2i_generate(self, z = self.patchify(image_prediction, self.patch_size * merge_size) image_input = self.patchify(image_prediction, self.patch_size, channel_first=True) - image_embeds = self.extract_feature(image_input.view(batch_size * grid_h*grid_w, -1), gen_model=True, grid_hw=grid_hw).view(batch_size, token_h*token_w, -1) - t_expanded = t.expand(batch_size*token_h*token_w) - timestep_embeddings = self.fm_modules['timestep_embedder'](t_expanded).view(batch_size, token_h*token_w, -1) + image_embeds = self.extract_feature( + image_input.view(batch_size * grid_h * grid_w, -1), gen_model=True, grid_hw=grid_hw + ).view(batch_size, token_h * token_w, -1) + t_expanded = t.expand(batch_size * token_h * token_w) + timestep_embeddings = self.fm_modules["timestep_embedder"](t_expanded).view( + batch_size, token_h * token_w, -1 + ) if self.add_noise_scale_embedding: noise_scale_tensor = torch.full_like(t_expanded, noise_scale / self.noise_scale_max_value) - noise_embeddings = self.fm_modules['noise_scale_embedder'](noise_scale_tensor).view(batch_size, token_h*token_w, -1) + noise_embeddings = self.fm_modules["noise_scale_embedder"](noise_scale_tensor).view( + batch_size, token_h * token_w, -1 + ) timestep_embeddings += noise_embeddings image_embeds = image_embeds + timestep_embeddings - - v_pred_condition = self._t2i_predict_v(image_embeds, indexes_image_condition, attention_mask_condition, past_key_values_condition, t, z, image_token_num=token_h*token_w, - timestep_embeddings=timestep_embeddings, image_size=image_size) - + v_pred_condition = self._t2i_predict_v( + image_embeds, + indexes_image_condition, + attention_mask_condition, + past_key_values_condition, + t, + z, + image_token_num=token_h * token_w, + timestep_embeddings=timestep_embeddings, + image_size=image_size, + ) if t >= cfg_interval[0] and t <= cfg_interval[1] and cfg_scale > 1: - v_pred_uncondition = self._t2i_predict_v(image_embeds, indexes_image_uncondition, attention_mask_uncondition, past_key_values_uncondition, t, z, image_token_num=token_h*token_w, - timestep_embeddings=timestep_embeddings, image_size=image_size) - if cfg_norm == 'cfg_zero_star': + v_pred_uncondition = self._t2i_predict_v( + image_embeds, + indexes_image_uncondition, + attention_mask_uncondition, + past_key_values_uncondition, + t, + z, + image_token_num=token_h * token_w, + timestep_embeddings=timestep_embeddings, + image_size=image_size, + ) + if cfg_norm == "cfg_zero_star": positive_flat = v_pred_condition.view(batch_size, -1) negative_flat = 
v_pred_uncondition.view(batch_size, -1) - alpha = optimized_scale(positive_flat,negative_flat) + alpha = optimized_scale(positive_flat, negative_flat) alpha = alpha.view(batch_size, *([1] * (len(v_pred_condition.shape) - 1))) alpha = alpha.to(positive_flat.dtype) - if (step_i <= 0): - v_pred = v_pred_condition*0. + if step_i <= 0: + v_pred = v_pred_condition * 0.0 else: - v_pred = v_pred_uncondition * alpha + cfg_scale * (v_pred_condition - v_pred_uncondition * alpha) + v_pred = v_pred_uncondition * alpha + cfg_scale * ( + v_pred_condition - v_pred_uncondition * alpha + ) else: v_pred = v_pred_uncondition + cfg_scale * (v_pred_condition - v_pred_uncondition) - if cfg_norm == 'global': - norm_v_condition = torch.norm(v_pred_condition, dim=(1,2), keepdim=True) - norm_v_cfg = torch.norm(v_pred, dim=(1,2), keepdim=True) + if cfg_norm == "global": + norm_v_condition = torch.norm(v_pred_condition, dim=(1, 2), keepdim=True) + norm_v_cfg = torch.norm(v_pred, dim=(1, 2), keepdim=True) scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) v_pred = v_pred * scale - elif cfg_norm == 'channel': + elif cfg_norm == "channel": norm_v_condition = torch.norm(v_pred_condition, dim=-1, keepdim=True) norm_v_cfg = torch.norm(v_pred, dim=-1, keepdim=True) scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0) @@ -682,7 +766,6 @@ def t2i_generate(self, return image_prediction - @property def lm_head(self): return self.language_model.get_output_embeddings() @@ -700,25 +783,29 @@ def set_output_embeddings(self, value): return self.language_model.set_output_embeddings(value) def get_thw_indexes(self, input_ids, grid_hw=None): - img_start_shift = torch.cat([torch.zeros(1, dtype=torch.long).to(input_ids.device), - (input_ids == self.img_start_token_id).long()], dim=0)[:-1] + img_start_shift = torch.cat( + [torch.zeros(1, dtype=torch.long).to(input_ids.device), (input_ids == self.img_start_token_id).long()], + dim=0, + )[:-1] not_img_token = (input_ids != self.img_context_token_id).long() - t_indexes = ((img_start_shift + not_img_token).cumsum(0) - 1) + t_indexes = (img_start_shift + not_img_token).cumsum(0) - 1 h_indexes = torch.zeros_like(t_indexes).to(t_indexes.device) w_indexes = torch.zeros_like(t_indexes).to(t_indexes.device) if grid_hw is not None: - selected = (input_ids == self.img_context_token_id) + selected = input_ids == self.img_context_token_id if selected.long().sum() > 0: abs_pos_w, abs_pos_h = build_abs_positions_from_grid_hw( - grid_hw // int(1 / self.downsample_ratio), device=t_indexes.device) + grid_hw // int(1 / self.downsample_ratio), device=t_indexes.device + ) h_indexes[selected] = abs_pos_h.to(t_indexes.device, t_indexes.dtype) w_indexes[selected] = abs_pos_w.to(t_indexes.device, t_indexes.dtype) return torch.stack([t_indexes, h_indexes, w_indexes], dim=0) NORM_MEAN = [0.5, 0.5, 0.5] -NORM_STD = [0.5, 0.5, 0.5] +NORM_STD = [0.5, 0.5, 0.5] + class NEOX2I: def __init__(self, model_path, device): @@ -736,7 +823,7 @@ def _denorm(self, x: torch.Tensor, mean=NORM_MEAN, std=NORM_STD): x: [B,3,H,W] normalized ((img-mean)/std). returns [0,1] clamped. 
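For the guidance branches above, a self-contained sketch of plain classifier-free guidance followed by the "global" cfg_norm rescale, which keeps the guided velocity's per-sample L2 norm (over tokens and channels) from exceeding the conditional prediction's norm; the shapes are illustrative:

    import torch

    v_cond = torch.randn(2, 64, 48)     # conditional velocity, toy shape [B, L, C]
    v_uncond = torch.randn(2, 64, 48)   # unconditional velocity
    cfg_scale = 4.0

    v_pred = v_uncond + cfg_scale * (v_cond - v_uncond)   # plain CFG

    # cfg_norm == "global": never let guidance inflate the velocity beyond the conditional norm.
    norm_cond = torch.norm(v_cond, dim=(1, 2), keepdim=True)
    norm_cfg = torch.norm(v_pred, dim=(1, 2), keepdim=True)
    v_pred = v_pred * (norm_cond / (norm_cfg + 1e-8)).clamp(min=0, max=1.0)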
""" mean = torch.tensor(mean, device=x.device, dtype=x.dtype).view(1, 3, 1, 1) - std = torch.tensor(std, device=x.device, dtype=x.dtype).view(1, 3, 1, 1) + std = torch.tensor(std, device=x.device, dtype=x.dtype).view(1, 3, 1, 1) return (x * std + mean).clamp(0, 1) def _get_dynamic_cache(self, past_kv): @@ -748,15 +835,17 @@ def _get_dynamic_cache(self, past_kv): for layer_idx in range(L): k = past_kv[layer_idx][0].unsqueeze(0).to(self.device, non_blocking=True) v = past_kv[layer_idx][1].unsqueeze(0).to(self.device, non_blocking=True) - past_kv_dc.update(key_states=k, value_states=v, layer_idx=layer_idx,) + past_kv_dc.update( + key_states=k, + value_states=v, + layer_idx=layer_idx, + ) return past_kv_dc - def t2i(self, past_kv, past_kv_txt, param: X2IParams): past_kv_dc = self._get_dynamic_cache(past_kv) past_kv_txt_dc = self._get_dynamic_cache(past_kv_txt) - text_lens = (param.past_kvcache.get_compressed_len(), - param.past_kvcache_text.get_compressed_len()) + text_lens = (param.past_kvcache.get_compressed_len(), param.past_kvcache_text.get_compressed_len()) output = self.model.t2i_generate( past_key_values_condition=past_kv_dc, past_key_values_uncondition=past_kv_txt_dc, @@ -766,7 +855,8 @@ def t2i(self, past_kv, past_kv_txt, param: X2IParams): image_size=(param.width, param.height), num_steps=param.steps, batch_size=param.num_images, - timestep_shift=param.timestep_shift) + timestep_shift=param.timestep_shift, + ) return self._post_process(output) @@ -774,19 +864,18 @@ def _post_process(self, output): images = self._denorm(output) images = (images.clamp(0, 1) * 255.0).round().to(torch.uint8).cpu() - base64_images = [ - base64.b64encode(io.encode_jpeg(img).numpy()).decode("utf-8") - for img in images - ] + base64_images = [base64.b64encode(io.encode_jpeg(img).numpy()).decode("utf-8") for img in images] return base64_images def it2i(self, past_kv, past_kv_txt, past_kv_img, param: X2IParams): past_kv_dc = self._get_dynamic_cache(past_kv) past_kv_txt_dc = self._get_dynamic_cache(past_kv_txt) past_kv_img_dc = self._get_dynamic_cache(past_kv_img) - text_lens = (param.past_kvcache.get_compressed_len(), - param.past_kvcache_text.get_compressed_len(), - param.past_kvcache_img.get_compressed_len()) + text_lens = ( + param.past_kvcache.get_compressed_len(), + param.past_kvcache_text.get_compressed_len(), + param.past_kvcache_img.get_compressed_len(), + ) output = self.model.it2i_generate( past_key_values_condition=past_kv_dc, past_key_values_text_uncondition=past_kv_txt_dc, diff --git a/lightllm/server/x2i_server/naive/modeling_neo_vit.py b/lightllm/server/x2i_server/naive/modeling_neo_vit.py index 63ded83e7d..b38a2677e5 100644 --- a/lightllm/server/x2i_server/naive/modeling_neo_vit.py +++ b/lightllm/server/x2i_server/naive/modeling_neo_vit.py @@ -9,9 +9,7 @@ from .configuration_neo_vit import NEOVisionConfig -def precompute_rope_freqs_sincos( - dim: int, max_position: int, base: float = 10000.0, device=None -): +def precompute_rope_freqs_sincos(dim: int, max_position: int, base: float = 10000.0, device=None): """预计算 RoPE 的 cos 和 sin 值 (1D)。""" inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim)) t = torch.arange(max_position, device=device).type_as(inv_freq) @@ -40,9 +38,10 @@ def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None): # Generate intra-image patch index (row-major order) patch_id_within_image = torch.arange(N_total, device=device) - patch_id_within_image = patch_id_within_image - torch.cumsum( - torch.cat([torch.tensor([0], 
device=device), N[:-1]]), dim=0 - )[patch_to_sample] + patch_id_within_image = ( + patch_id_within_image + - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample] + ) # Get H/W for each patch according to its image W_per_patch = W[patch_to_sample] @@ -63,8 +62,8 @@ def apply_rotary_emb_1d( # positions: (..., seq_len) # cos_cached: (max_pos, dim_part / 2) - cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2) - sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2) + cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2) + sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2) x1 = x[..., 0::2] x2 = x[..., 1::2] @@ -85,7 +84,7 @@ def apply_2d_rotary_pos_emb( cos_cached_y: torch.Tensor, sin_cached_y: torch.Tensor, abs_positions_x: torch.Tensor, - abs_positions_y: torch.Tensor + abs_positions_y: torch.Tensor, ): """应用2D RoPE到输入张量x。""" dim = x.shape[-1] @@ -97,13 +96,9 @@ def apply_2d_rotary_pos_emb( x_part_2 = x[..., dim_half:] # 将与 abs_positions_x 相关的旋转应用于 x_part_1 - rotated_part_1 = apply_rotary_emb_1d( - x_part_1, cos_cached_x, sin_cached_x, abs_positions_x - ) + rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x) # 将与 abs_positions_y 相关的旋转应用于 x_part_2 - rotated_part_2 = apply_rotary_emb_1d( - x_part_2, cos_cached_y, sin_cached_y, abs_positions_y - ) + rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y) # 将它们重新拼接起来。确保顺序与你分割时一致。 return torch.cat((rotated_part_1, rotated_part_2), dim=-1) @@ -123,10 +118,16 @@ def __init__(self, config: NEOVisionConfig): self.patch_size = config.patch_size self.patch_embedding = nn.Conv2d( - in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, ) self.dense_embedding = nn.Conv2d( - in_channels=self.embed_dim, out_channels=self.llm_embed_dim, kernel_size=self.downsample_factor, stride=self.downsample_factor + in_channels=self.embed_dim, + out_channels=self.llm_embed_dim, + kernel_size=self.downsample_factor, + stride=self.downsample_factor, ) self.gelu = nn.GELU() @@ -149,11 +150,13 @@ def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw): """ abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device) embeddings = apply_2d_rotary_pos_emb( - patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32 - self.cos_cached_x, self.sin_cached_x, - self.cos_cached_y, self.sin_cached_y, + patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32 + self.cos_cached_x, + self.sin_cached_x, + self.cos_cached_y, + self.sin_cached_y, abs_pos_x, - abs_pos_y + abs_pos_y, ).to(self.patch_embedding.weight.dtype) return embeddings @@ -164,14 +167,14 @@ def forward(self, pixel_values: torch.FloatTensor, grid_hw=None) -> torch.Tensor 3, self.patch_size, self.patch_size, - ) # [28072, 768] -> [28072, 3, 16, 16] + ) # [28072, 768] -> [28072, 3, 16, 16] patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim) self.cos_cached_x = self.cos_cached_x.to(patch_embeds.device) self.sin_cached_x = self.sin_cached_x.to(patch_embeds.device) self.cos_cached_y = self.cos_cached_y.to(patch_embeds.device) self.sin_cached_y = self.sin_cached_y.to(patch_embeds.device) - patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, 
grid_hw) # [28072, 1024] - assert (grid_hw[:,0] * grid_hw[:,1]).sum() == patch_embeds.shape[0] + patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) # [28072, 1024] + assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[0] patches_list = [] cur_position = 0 @@ -186,18 +189,18 @@ def forward(self, pixel_values: torch.FloatTensor, grid_hw=None) -> torch.Tensor embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C) assert cur_position == patch_embeds.shape[0] - assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor**2) + assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2) return embeddings class NEOVisionModel(PreTrainedModel): - main_input_name = 'pixel_values' + main_input_name = "pixel_values" _supports_flash_attn_2 = True supports_gradient_checkpointing = True config_class = NEOVisionConfig # support transformers 4.51.+ - _tp_plan = '' + _tp_plan = "" def __init__(self, config: NEOVisionConfig): super().__init__(config) @@ -206,12 +209,12 @@ def __init__(self, config: NEOVisionConfig): self.embeddings = NEOVisionEmbeddings(config) def forward( - self, - pixel_values: Optional[torch.FloatTensor] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_embeds: Optional[torch.FloatTensor] = None, - grid_hw: Optional[torch.Tensor] = None + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_embeds: Optional[torch.FloatTensor] = None, + grid_hw: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -219,7 +222,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None and pixel_embeds is None: - raise ValueError('You have to specify pixel_values or pixel_embeds') + raise ValueError("You have to specify pixel_values or pixel_embeds") if pixel_embeds is not None: hidden_states = pixel_embeds diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py index 4abfde5a83..ec2d9d1760 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py @@ -148,9 +148,7 @@ def _build_inputs( n_spans = int(torch.randint(1, num_image_spans_max + 1, (1,), generator=g).item()) start_pack = int(b_start_loc[i].item()) for _ in range(n_spans): - span_len = int( - torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item() - ) + span_len = int(torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item()) span_len = min(span_len, M) s_rel = int(torch.randint(0, M - span_len + 1, (1,), generator=g).item()) b_image_token_tag[start_pack + s_rel : start_pack + s_rel + span_len] = True diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py index 9a635f0445..41f6b476ba 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py @@ -172,9 
+172,7 @@ def _build_inputs( n_spans = int(torch.randint(1, num_image_spans_max + 1, (1,), generator=g).item()) start_pack = int(b_q_start_loc[i].item()) for _ in range(n_spans): - span_len = int( - torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item() - ) + span_len = int(torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item()) span_len = min(span_len, M) s_rel = int(torch.randint(0, M - span_len + 1, (1,), generator=g).item()) b_image_token_tag[start_pack + s_rel : start_pack + s_rel + span_len] = True @@ -214,9 +212,9 @@ def _fa3_prefill_with_image_tag(inputs: dict) -> torch.Tensor: device = q.device # Build page_table[b, p] = req_to_token_indexs[b_req_idx[b], p]. - page_table = inputs["req_to_token_indexs"][ - inputs["b_req_idx"].long(), : inputs["max_seq_len_in_batch"] - ].to(torch.int32) + page_table = inputs["req_to_token_indexs"][inputs["b_req_idx"].long(), : inputs["max_seq_len_in_batch"]].to( + torch.int32 + ) q_seq_lens_t = inputs["b_seq_len"].to(torch.int32) - inputs["b_prompt_cache_len"].to(torch.int32) @@ -366,9 +364,7 @@ def _run_case( (3, 8, 2, 128, torch.bfloat16, 6, 8, 1024), ], ) -def test_fa3_neo_prefill_with_image_tag( - batch, Hq, Hk, D, dtype, seed, max_q_seq_len, max_prompt_cache_len -): +def test_fa3_neo_prefill_with_image_tag(batch, Hq, Hk, D, dtype, seed, max_q_seq_len, max_prompt_cache_len): abs_err, rel_err, cos = _run_case( batch=batch, Hq=Hq, @@ -426,9 +422,7 @@ def fa3_run(): o_triton = torch.empty_like(inputs["q"]) # Kernel signature requires position_ids but the current masking path does # not read it; zeros are fine for perf measurement. - position_ids_0 = torch.zeros( - inputs["q"].shape[0], dtype=torch.int32, device=inputs["q"].device - ) + position_ids_0 = torch.zeros(inputs["q"].shape[0], dtype=torch.int32, device=inputs["q"].device) def triton_run(): context_attention_fwd_neo( diff --git a/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py b/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py index ca6c5043c9..2bcced470b 100644 --- a/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py +++ b/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py @@ -5,10 +5,7 @@ # ========================================================= # GPU guard -pytestmark = pytest.mark.skipif( - not torch.cuda.is_available(), - reason="CUDA not available" -) +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") # ========================================================= # 工具函数:生成 GPU KV cache def make_kv(L, T, H, D, device="cuda"): @@ -20,14 +17,14 @@ def make_kv(L, T, H, D, device="cuda"): # Test 1: 基础功能(每 token 对应一个 page slot) def test_basic_copy(): L, T, H, D = 2, 8, 4, 16 - B = 1 # token_block_size - P = 4 # all_page_num + B = 1 # token_block_size + P = 4 # all_page_num gpu_kv = make_kv(L, T, H, D) cpu_kv = torch.zeros((P, L, B, H, D), dtype=torch.float32, pin_memory=True) token_indexes = torch.tensor([1, 3, 5, 7], device="cuda") - page_indexes = torch.tensor([0, 1, 2, 3], device="cuda") + page_indexes = torch.tensor([0, 1, 2, 3], device="cuda") offload_gpu_kv_to_cpu_all( token_indexes, @@ -43,10 +40,10 @@ def test_basic_copy(): for i in range(len(token_indexes)): token = token_indexes[i].item() - page = page_indexes[i].item() + page = page_indexes[i].item() expected = gpu_kv[:, token, :, :].cpu() # (L,H,D) - actual = cpu_kv[page, :, 0, :, :] # (L,H,D) + actual = cpu_kv[page, :, 0, :, :] # (L,H,D) assert torch.allclose(expected, actual) @@ -62,7 +59,7 @@ def 
test_random_tokens(): cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True) token_indexes = torch.tensor([10, 2, 7, 15, 0, 3], device="cuda") - page_indexes = torch.arange(6, device="cuda") + page_indexes = torch.arange(6, device="cuda") offload_gpu_kv_to_cpu_all( token_indexes, @@ -80,10 +77,7 @@ def test_random_tokens(): t = token_indexes[i].item() p = page_indexes[i].item() - assert torch.allclose( - cpu_kv[p, :, 0, :, :], - gpu_kv[:, t, :, :].cpu() - ) + assert torch.allclose(cpu_kv[p, :, 0, :, :], gpu_kv[:, t, :, :].cpu()) # ========================================================= @@ -100,7 +94,7 @@ def test_with_scale(): cpu_scale = torch.zeros((P, L, B, H, D // 8), pin_memory=True) token_indexes = torch.tensor([1, 2, 3], device="cuda") - page_indexes = torch.tensor([0, 1, 2], device="cuda") + page_indexes = torch.tensor([0, 1, 2], device="cuda") offload_gpu_kv_to_cpu_all( token_indexes, @@ -119,16 +113,10 @@ def test_with_scale(): p = page_indexes[i].item() # KV - assert torch.allclose( - cpu_kv[p, :, 0, :, :], - gpu_kv[:, t, :, :].cpu() - ) + assert torch.allclose(cpu_kv[p, :, 0, :, :], gpu_kv[:, t, :, :].cpu()) # scale - assert torch.allclose( - cpu_scale[p, :, 0, :], - gpu_scale[:, t, :].cpu() - ) + assert torch.allclose(cpu_scale[p, :, 0, :], gpu_scale[:, t, :].cpu()) # ========================================================= @@ -144,11 +132,14 @@ def test_tp_split(): cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True) token_indexes = torch.tensor([1, 2], device="cuda") - page_indexes = torch.tensor([0, 1], device="cuda") + page_indexes = torch.tensor([0, 1], device="cuda") offload_gpu_kv_to_cpu_all( - token_indexes, gpu_kv, None, - cpu_kv, None, + token_indexes, + gpu_kv, + None, + cpu_kv, + None, page_indexes, tp_index=0, tp_world_size=tp_world_size, @@ -156,8 +147,11 @@ def test_tp_split(): ) offload_gpu_kv_to_cpu_all( - token_indexes, gpu_kv, None, - cpu_kv, None, + token_indexes, + gpu_kv, + None, + cpu_kv, + None, page_indexes, tp_index=1, tp_world_size=tp_world_size, @@ -170,15 +164,9 @@ def test_tp_split(): t = token_indexes[i].item() p = page_indexes[i].item() - assert torch.allclose( - cpu_kv[p, :, 0, :split, :], - gpu_kv[:, t, :split, :].cpu() - ) + assert torch.allclose(cpu_kv[p, :, 0, :split, :], gpu_kv[:, t, :split, :].cpu()) - assert torch.allclose( - cpu_kv[p, :, 0, split:, :], - gpu_kv[:, t, split:, :].cpu() - ) + assert torch.allclose(cpu_kv[p, :, 0, split:, :], gpu_kv[:, t, split:, :].cpu()) # ========================================================= @@ -192,7 +180,7 @@ def test_empty(): cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True) token_indexes = torch.tensor([], dtype=torch.long, device="cuda") - page_indexes = torch.tensor([], dtype=torch.long, device="cuda") + page_indexes = torch.tensor([], dtype=torch.long, device="cuda") offload_gpu_kv_to_cpu_all( token_indexes, @@ -208,5 +196,6 @@ def test_empty(): assert torch.all(cpu_kv == 0) + if __name__ == "__main__": - pytest.main() \ No newline at end of file + pytest.main() From 2f9489c3816690d746a2e35ab2919fbe371d4ae2 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Wed, 6 May 2026 14:59:11 +0800 Subject: [PATCH 39/41] add seed rng session --- lightllm/server/api_openai.py | 2 + lightllm/server/core/objs/x2i_params.py | 6 ++ lightllm/server/httpserver/manager.py | 4 + .../server/x2i_server/lightx2v/adapter.py | 26 ++++- lightllm/server/x2i_server/manager.py | 29 +++++- lightllm/server/x2i_server/rng_state_cache.py | 96 +++++++++++++++++++ 6 files changed, 160 insertions(+), 3 
deletions(-)
 create mode 100644 lightllm/server/x2i_server/rng_state_cache.py

diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index d137bba162..5f28827155 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -634,6 +634,7 @@ async def _get_text_generator_input(request: ChatCompletionRequest):
         "n": request.n,
         "best_of": request.n,
         "add_special_tokens": False,
+        "seed": request.seed,
     }
     # Structured output handling
     if request.response_format:
@@ -771,6 +772,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
     x2i_params = X2IParams()
     x2i_params.init_from_image_config(request.image_config)
+    sampling_params.seed = x2i_params.seed
     enable_thinking = request.chat_template_kwargs.get("enable_thinking", False)
     print(f"x2i_params: {x2i_params} {image_only} {enable_thinking}", flush=True)
diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py
index cac7b9abdd..9463ffc01f 100644
--- a/lightllm/server/core/objs/x2i_params.py
+++ b/lightllm/server/core/objs/x2i_params.py
@@ -45,6 +45,11 @@ class X2IParams(ctypes.Structure):
         ("past_kvcache_img", PastKVCachePageList),
         ("total_prompt_tokens", ctypes.c_int),
         ("request_id", ctypes.c_int64),
+        # session_id is used on the x2i server side to cache RNG state per chat session,
+        # fixing concurrent sessions overwriting each other's global torch/cuda RNG.
+        # The httpserver binds it to the request_id of the session's first image generation,
+        # and it stays fixed for the rest of the interleaved run (not overwritten once first_image=False).
+        ("session_id", ctypes.c_int64),
     ]

     _width: int = 1024
@@ -80,6 +85,7 @@ def _get(key, default):
         self.dynamic_resolution = _get("dynamic_resolution", X2IParams._dynamic_resolution)
         self.total_prompt_tokens = 0
         self.request_id = 0
+        self.session_id = 0
         self.has_updated_hw = False
         self.first_image = True

diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 653117a071..9635965dd2 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -552,6 +552,10 @@ async def generation_wrapper(prompt, sample, multimodal, request):
             # use the first request id as the gen image request id
             x2i_req_id = generate_req_ids[0]
             generation_params.request_id = x2i_req_id
+            # On the first generation, pin this request_id as the session_id; all later images of
+            # the same conversation reuse it, so the x2i side can cache RNG state per session.
+            if generation_params.first_image:
+                generation_params.session_id = x2i_req_id
             req_status = X2IReqStatus(generation_params, generate_req_ids)
             self.req_id_to_x2i_reqs[generation_params.request_id] = req_status

diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py
index 95a795f617..60d04bb6d8 100644
--- a/lightllm/server/x2i_server/lightx2v/adapter.py
+++ b/lightllm/server/x2i_server/lightx2v/adapter.py
@@ -15,6 +15,7 @@
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease, CfgNormType
 from ..past_kv_cache_client import PastKVCacheClient
+from ..rng_state_cache import RngStateCache

 logger = init_logger(__name__)

@@ -41,6 +42,11 @@ def __init__(self, args: StartArgs, rank: int, world_size: int):

         self.task_dist_group = dist.new_group(backend="gloo", timeout=datetime.timedelta(days=30))

+        # Per-chat-session RNG snapshot, see lightllm/server/x2i_server/rng_state_cache.py.
+        # Each worker keeps its own copy; since this process runs _process serially on a single
+        # stream and tasks are synced via broadcast, the RNG evolution stays identical across ranks.
+        self.rng_state_cache = RngStateCache()
+
     def _init_pipeline(self):
         os.environ["MASTER_ADDR"] = "127.0.0.1"
         os.environ["MASTER_PORT"] = str(self.args.x2i_worker_nccl_port)
@@ -123,14 +129,30 @@ async def _process(self, param: X2IParams):
             past_kv_cache_text,
             past_kv_cache_img,
         )
-        seed = param.seed if param.first_image else None
-        logger.info(f"seed: {seed} {param.seed} first_image: {param.first_image}")
+
+        session_id = param.session_id
+        if param.first_image:
+            # First image of this session: initialize the global RNG with the provided seed.
+            seed = param.seed
+        else:
+            # Later images: restore the RNG state this session ended with, so other sessions' seed_all calls cannot pollute it.
+            restored = self.rng_state_cache.restore(session_id)
+            seed = None
+            if not restored:
+                logger.warning(
+                    f"session {session_id} rng state miss (maybe expired or first call after restart), "
+                    f"fallback to current global rng"
+                )
+        logger.info(f"seed: {seed} param.seed: {param.seed} first_image: {param.first_image} session_id: {session_id}")

         image = self.pipe.generate(
             seed=seed,
             save_result_path="",
             target_shape=[param.height, param.width],
         )
+        # Save the current RNG state for this session's next image.
+        self.rng_state_cache.save(session_id)
+
         return [image]

diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index c48f283e79..0d169d3673 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -20,6 +20,7 @@
 from lightllm.utils.dist_utils import set_current_device_id
 from lightllm.utils.start_utils import start_submodule_processes
 from .past_kv_cache_client import PastKVCacheClient
+from .rng_state_cache import RngStateCache

 logger = init_logger(__name__)

@@ -62,6 +63,10 @@ def __init__(

         self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True)

+        # Per-chat-session RNG snapshot, so concurrent sessions don't clobber each other's
+        # global torch / cuda RNG state between successive image generations.
+        self.rng_state_cache = RngStateCache()
+
     async def wait_to_model_ready(self):
         if self.world_size <= 1:
@@ -91,6 +96,22 @@ async def wait_to_model_ready(self):
             args = [(self.args, rank, self.world_size) for rank in range(self.world_size)]
             start_submodule_processes(funcs, args)

+    def _prepare_seed_for_iter(self, param: X2IParams, i: int, session_id: int):
+        """Decide which seed to pass to the pipeline for this generate() call.
+
+        - First image of a chat session (first_image=True and i == 0): pass the seed; the pipeline seed_alls internally.
+        - Later images (first_image=False and i == 0): first restore the RNG state saved for this
+          session, then pass seed=None so generation continues from the previous image's RNG; this
+          way seed_all calls interleaved by other sessions cannot pollute this session's RNG.
+        - Images with i > 0 within the same call: just let the global RNG run on; no reset, no restore.
+        """
+        if i == 0:
+            if param.first_image:
+                return param.seed
+            self.rng_state_cache.restore(session_id)
+            return None
+        return None
+
     async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams):
         if self.use_naive_x2i:
             images = self.naive_x2i.t2i(past_kv_cache, past_kv_cache_text, param)
@@ -105,14 +126,18 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams
                 timestep_shift=param.timestep_shift,
             )
             images = []
+            session_id = param.session_id
             for i in range(param.num_images):
                 self.gen_pipe.runner.set_kvcache(past_kv_cache, past_kv_cache_text)
+                seed_arg = self._prepare_seed_for_iter(param, i, session_id)
                 image = self.gen_pipe.generate(
-                    seed=param.seed if param.first_image else None,
+                    seed=seed_arg,
                     save_result_path="",  # returns base64; no need to pass a path
                     target_shape=[param.height, param.width],  # Height, Width
                 )
                 images.append(image)
+            # Save the current RNG state; this session's next image continues from here.
+            self.rng_state_cache.save(session_id)
         return images

     async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams):
@@ -131,6 +156,8 @@ async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_i
             images = []
             for i in range(param.num_images):
                 self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img)
+                # Every it2i image carries an explicit seed (bound to img_len + i), so seed_all fires each time;
+                # it neither depends on nor pollutes the t2i session RNG state, so nothing to cache or restore.
                 image = self.gen_pipe.generate(
                     seed=param.seed + param.past_kvcache_img.img_len + i,
                     save_result_path="",  # returns base64; no need to pass a path
diff --git a/lightllm/server/x2i_server/rng_state_cache.py b/lightllm/server/x2i_server/rng_state_cache.py
new file mode 100644
index 0000000000..54a47f2e0d
--- /dev/null
+++ b/lightllm/server/x2i_server/rng_state_cache.py
@@ -0,0 +1,96 @@
+"""Per-session RNG state cache for the x2i server.
+
+In interleaved image-text scenarios, only the very first image of a chat
+session needs an explicit seed; subsequent images should continue from the
+RNG state left by the previous image generation. With multiple chat sessions
+running concurrently, the global torch / cuda / numpy / random RNG state is
+shared across sessions, so seeding for session B's first image would corrupt
+session A's continuation.
+
+This cache snapshots the RNG state after each generation per session_id and
+restores it before the next generation of the same session, so that each
+session sees a private, deterministic RNG stream.
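A minimal usage sketch (the session id and the sampled shapes are illustrative): saving after one generation and restoring before the next reproduces the same draw even if another session reseeds the shared global RNG in between.

    import torch
    from lightllm.server.x2i_server.rng_state_cache import RngStateCache

    cache = RngStateCache()

    torch.manual_seed(1234)      # session 42's first image seeds the global RNG
    _ = torch.randn(4)           # ... image generation consumes randomness
    cache.save(42)               # snapshot after the image
    expected = torch.randn(4)    # what session 42's next image would sample

    torch.manual_seed(999)       # a different session reseeds the shared RNG
    assert cache.restore(42)     # bring session 42's state back
    assert torch.equal(torch.randn(4), expected)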
+""" + +import random +import time +from typing import Dict, Tuple + +import numpy as np +import torch + +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def capture_rng_state() -> Dict: + state = { + "random": random.getstate(), + "numpy": np.random.get_state(), + "torch": torch.get_rng_state(), + } + if torch.cuda.is_available(): + state["torch_cuda"] = torch.cuda.get_rng_state_all() + return state + + +def restore_rng_state(state: Dict) -> None: + random.setstate(state["random"]) + np.random.set_state(state["numpy"]) + torch.set_rng_state(state["torch"]) + cuda_state = state.get("torch_cuda") + if cuda_state is not None and torch.cuda.is_available(): + torch.cuda.set_rng_state_all(cuda_state) + + +class RngStateCache: + """LRU + TTL cache for RNG snapshots keyed by chat session id. + + - `save(session_id)`: snapshot current global RNG state for the session. + - `restore(session_id)`: restore previously saved state, returns True on hit. + - `discard(session_id)`: drop the session entry if any. + - Entries expire after `ttl_seconds` of inactivity, and the cache is hard + capped at `max_size` (oldest evicted first) to bound memory. + """ + + def __init__(self, max_size: int = 1024, ttl_seconds: float = 600.0): + self._max_size = max_size + self._ttl = ttl_seconds + # session_id -> (last_used_ts, state_dict) + self._cache: Dict[int, Tuple[float, Dict]] = {} + + def save(self, session_id: int) -> None: + if session_id == 0: + return + self._cache[session_id] = (time.time(), capture_rng_state()) + self._evict() + + def restore(self, session_id: int) -> bool: + if session_id == 0: + return False + item = self._cache.get(session_id) + if item is None: + return False + ts, state = item + if self._ttl > 0 and time.time() - ts > self._ttl: + self._cache.pop(session_id, None) + logger.warning(f"RNG state for session {session_id} expired (TTL={self._ttl}s)") + return False + restore_rng_state(state) + self._cache[session_id] = (time.time(), state) + return True + + def discard(self, session_id: int) -> None: + self._cache.pop(session_id, None) + + def _evict(self) -> None: + if self._ttl > 0: + now = time.time() + expired = [sid for sid, (ts, _) in self._cache.items() if now - ts > self._ttl] + for sid in expired: + self._cache.pop(sid, None) + if len(self._cache) > self._max_size: + sorted_items = sorted(self._cache.items(), key=lambda kv: kv[1][0]) + for sid, _ in sorted_items[: len(self._cache) - self._max_size]: + self._cache.pop(sid, None) From 66b420c59e6d0eca12ef0c051cb40daa9f11bdc5 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 8 May 2026 08:34:31 +0000 Subject: [PATCH 40/41] Fix Neo image-token attention scope in Triton prefill --- .../layer_infer/transformer_layer_infer.py | 2 +- lightllm/models/neo_chat_moe/infer_struct.py | 4 +- .../layer_infer/transformer_layer_infer.py | 2 +- .../context_attention_fwd_neo.py | 49 +++++++++--- .../triton_kernel/get_neo_position.py | 17 ++-- .../test_context_attention_fwd_neo.py | 79 +++++++++++-------- 6 files changed, 99 insertions(+), 54 deletions(-) diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index c5c4f3c343..e9623b99a6 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -113,7 +113,7 @@ def _context_attention_kernel( infer_state.b_ready_cache_len, infer_state.max_q_seq_len, 
infer_state.req_manager.req_to_token_indexs, - infer_state.b_image_token_tag, + infer_state.b_image_token_end, ) return o_tensor diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py index 1693bcb964..c52147cd9b 100644 --- a/lightllm/models/neo_chat_moe/infer_struct.py +++ b/lightllm/models/neo_chat_moe/infer_struct.py @@ -20,7 +20,7 @@ def __init__(self): def init_some_extra_state(self, model: LlamaTpPartModel): LlamaInferStateInfo.init_some_extra_state(self, model) if self.is_prefill: - self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda( + self.b_image_token_end = torch.zeros([self.position_ids.size(0)], dtype=torch.int32, device="cpu").cuda( non_blocking=True ) self.position_ids = self.get_neo_position(self.multimodal_params) @@ -96,6 +96,6 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor: b_ready_cache_len=self.b_ready_cache_len, b_q_seq_len=self.b_q_seq_len, b_start_loc=self.b_q_start_loc, - b_image_token_tag=self.b_image_token_tag, + b_image_token_end=self.b_image_token_end, ) return position_ids diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 66a320fe02..271c9592c3 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -111,7 +111,7 @@ def _context_attention_kernel( infer_state.b_ready_cache_len, infer_state.max_q_seq_len, infer_state.req_manager.req_to_token_indexs, - infer_state.b_image_token_tag, + infer_state.b_image_token_end, ) return o_tensor diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py index c9795e113a..8cefe486c0 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py @@ -34,7 +34,7 @@ def _fwd_kernel( stride_req_to_tokens_s, kv_group_num, b_prompt_cache_len, - b_image_token_tag, + b_image_token_end, H: tl.constexpr, QK_HEAD_DIM: tl.constexpr, V_HEAD_DIM: tl.constexpr, @@ -79,18 +79,25 @@ def _fwd_kernel( acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32) # absolute q positions in the request q_pos = prompt_cache_len + offs_m # [M] - q_image_token_tag = tl.load(b_image_token_tag + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=False) + q_gid = tl.load(position_ids + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=-1) + q_image_end = tl.load(b_image_token_end + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=0) - # per-M-block: only scan full K range if this block has image tokens - has_image = tl.reduce_or(q_image_token_tag.to(tl.int32), axis=0) > 0 causal_end = tl.minimum(prompt_cache_len + block_start_loc + BLOCK_M, total_len) - block_end_loc = tl.where(has_image, total_len, causal_end) + block_image_end = tl.minimum(tl.max(q_image_end, axis=0), total_len) + block_end_loc = tl.maximum(causal_end, block_image_end) for start_n in range(0, block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) k_pos = start_n + offs_n # [N] k_valid = k_pos < block_end_loc + k_in_new = k_pos >= prompt_cache_len + k_rel = k_pos - prompt_cache_len + k_gid = tl.load( + position_ids + cur_batch_in_all_start_index + k_rel, + mask=k_valid & k_in_new, + other=-2, + ) # map logical pos -> 
mem_index (for K/V) kv_loc = tl.load( @@ -105,8 +112,14 @@ def _fwd_kernel( qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) - # mask: causal OR same gid (only possible inside NEW part) - mask = ((q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None]) & k_valid[None, :] + causal_mask = q_pos[:, None] >= k_pos[None, :] + image_mask = ( + (q_image_end[:, None] > 0) + & k_in_new[None, :] + & (k_pos[None, :] < q_image_end[:, None]) + & (q_gid[:, None] == k_gid[None, :]) + ) + mask = (causal_mask | image_mask) & k_valid[None, :] qk = tl.where(mask, qk * sm_scale, -1.0e8) # online softmax @@ -151,7 +164,7 @@ def context_attention_fwd_neo( b_prompt_cache_len, max_input_len, req_to_token_indexs, - b_image_token_tag, + b_image_token_end, ): # minimal safety: position_ids must cover packed q rows assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) @@ -199,7 +212,7 @@ def context_attention_fwd_neo( req_to_token_indexs.stride(1), kv_group_num=kv_group_num, b_prompt_cache_len=b_prompt_cache_len, - b_image_token_tag=b_image_token_tag, + b_image_token_end=b_image_token_end, H=head, QK_HEAD_DIM=Lk, V_HEAD_DIM=Lk, @@ -215,6 +228,7 @@ def reference_attention( k, v, position_ids_q, # 1D packed like q (only NEW tokens) + b_image_token_end, b_req_idx, b_start_loc, b_seq_len, @@ -240,6 +254,7 @@ def reference_attention( q_start = int(b_start_loc[b].item()) q_blk = q[q_start : q_start + new_len] # [M, Hq, D] gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M] + image_end_new = b_image_token_end[q_start : q_start + new_len].to(torch.int64) # [M] # gather K/V for full request by logical pos -> mem_index token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L] @@ -267,7 +282,12 @@ def reference_attention( k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new k_gid[k_in_new] = gid_new[k_rel[k_in_new]] - allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) + allow = allow | ( + (image_end_new[:, None] > 0) + & k_in_new[None, :] + & (k_pos[None, :] < image_end_new[:, None]) + & (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) + ) # scores: [Hq, M, L] q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] @@ -337,17 +357,20 @@ def make_test_case( # position_ids_q: only NEW tokens, packed like q position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32) + b_image_token_end = torch.zeros((sum_q,), device=device, dtype=torch.int32) for b in range(batch): M = int(new_lens[b].item()) + P = int(prompt_lens[b].item()) start = int(b_start_loc[b].item()) - gid = torch.arange(M, device=device, dtype=torch.int32) + gid = torch.arange(P, P + M, device=device, dtype=torch.int32) # make one repeated block inside NEW part to simulate image tokens if M >= 4 and torch.rand((), device=device).item() > 0.3: s = int(torch.randint(0, M - 2, (1,), device=device).item()) e = min(M, s + 3) gid[s:e] = gid[s] + b_image_token_end[start + s : start + e] = P + e position_ids_q[start : start + M] = gid @@ -362,6 +385,7 @@ def make_test_case( v, o, position_ids_q, + b_image_token_end, b_req_idx, b_start_loc, b_seq_len, @@ -378,6 +402,7 @@ def check_once(device="cuda", dtype=torch.float16, seed=0): v, o, position_ids_q, + b_image_token_end, b_req_idx, b_start_loc, b_seq_len, @@ -398,6 +423,7 @@ def check_once(device="cuda", dtype=torch.float16, seed=0): b_prompt_cache_len, max_new_len, req_to_token_indexs, + b_image_token_end, ) ref = reference_attention( @@ -405,6 +431,7 @@ def check_once(device="cuda", 
dtype=torch.float16, seed=0): k, v, position_ids_q, + b_image_token_end, b_req_idx, b_start_loc, b_seq_len, diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py index dc57870bf1..b2e1af9f8f 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py @@ -16,7 +16,7 @@ def _get_neo_position_triton( b_ready_cache_len: torch.Tensor, b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, - b_image_token_tag: torch.Tensor, + b_image_token_end: torch.Tensor, BLOCK_SIZE: tl.constexpr, ) -> torch.Tensor: cur_batch = tl.program_id(0) @@ -30,6 +30,7 @@ def _get_neo_position_triton( local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i) image_start_idx = start_loc + local_image_start_idx - cache_len image_len = tl.load(b_image_len + image_start_num + i) + image_end = local_image_start_idx + image_len # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1) image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2) @@ -40,8 +41,8 @@ def _get_neo_position_triton( h_pos = off // image_w w_pos = off % image_w tl.store( - b_image_token_tag + off + image_start_idx, - True, + b_image_token_end + off + image_start_idx, + image_end, mask=(off < image_len) & (off + local_image_start_idx - cache_len < q_seq_len) & (local_image_start_idx - cache_len + off >= 0), @@ -97,7 +98,7 @@ def get_neo_position_triton( b_ready_cache_len: torch.Tensor, b_q_seq_len: torch.Tensor, b_start_loc: torch.Tensor, - b_image_token_tag: torch.Tensor, + b_image_token_end: torch.Tensor, ) -> torch.Tensor: batch_size = b_q_seq_len.shape[0] @@ -116,7 +117,7 @@ def get_neo_position_triton( b_ready_cache_len=b_ready_cache_len, b_q_seq_len=b_q_seq_len, b_start_loc=b_start_loc, - b_image_token_tag=b_image_token_tag, + b_image_token_end=b_image_token_end, BLOCK_SIZE=BLOCK_SIZE, ) @@ -133,7 +134,7 @@ def test(): .expand(3, -1) .contiguous() ) - b_image_token_tag = torch.zeros([position_ids.size(1)], dtype=torch.bool, device="cuda") + b_image_token_end = torch.zeros([position_ids.size(1)], dtype=torch.int32, device="cuda") position_ids[1:].zero_() b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda") b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda") @@ -148,10 +149,10 @@ def test(): b_ready_cache_len, b_q_seq_len, b_start_loc, - b_image_token_tag, + b_image_token_end, ) - print(b_image_token_tag) + print(b_image_token_end) print(position_ids) # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1) diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py index ec2d9d1760..730d4263ee 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py @@ -5,11 +5,11 @@ batch element we gather K/V for the whole request (prompt + new tokens) via ``req_to_token_indexs`` and apply:: - allow[m, k] = (k <= q_pos[m]) OR image_tag[m] for k in [0, total) + allow[m, k] = (k <= q_pos[m]) OR same_image_group[m, k] -i.e. normal queries are causal, image-token queries can see every real key in -the request. If the Triton kernel disagrees with this reference, the kernel is -wrong. +i.e. 
normal queries are causal, and image-token queries can only see future +tokens from the same image span. If the Triton kernel disagrees with this +reference, the kernel is wrong. Run directly for quick debugging: @@ -35,12 +35,13 @@ def torch_reference_context_attention_neo( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + position_ids_0: torch.Tensor, b_req_idx: torch.Tensor, b_start_loc: torch.Tensor, b_seq_len: torch.Tensor, b_prompt_cache_len: torch.Tensor, req_to_token_indexs: torch.Tensor, - b_image_token_tag: torch.Tensor, + b_image_token_end: torch.Tensor, ) -> torch.Tensor: device = q.device dtype = q.dtype @@ -61,7 +62,8 @@ def torch_reference_context_attention_neo( q_start = int(b_start_loc[b].item()) q_blk = q[q_start : q_start + new] # [M, Hq, D] - image_tag = b_image_token_tag[q_start : q_start + new].to(torch.bool) + q_gid = position_ids_0[q_start : q_start + new].to(torch.int64) + q_image_end = b_image_token_end[q_start : q_start + new].to(torch.int64) token_locs = req_to_token_indexs[req_idx, :total].to(torch.int64) k_blk = k[token_locs] # [total, Hk, D] @@ -70,7 +72,16 @@ def torch_reference_context_attention_neo( q_pos = torch.arange(prompt, total, device=device, dtype=torch.int64) # [M] k_pos = torch.arange(0, total, device=device, dtype=torch.int64) # [total] causal = k_pos[None, :] <= q_pos[:, None] - allow = causal | image_tag[:, None] + + k_gid = torch.full((total,), -1, device=device, dtype=torch.int64) + k_gid[prompt:total] = q_gid + same_image = ( + (q_image_end[:, None] > 0) + & (k_pos[None, :] >= prompt) + & (k_pos[None, :] < q_image_end[:, None]) + & (q_gid[:, None] == k_gid[None, :]) + ) + allow = causal | same_image out_blk = torch.empty_like(q_blk) for h in range(Hq): @@ -137,34 +148,39 @@ def _build_inputs( req_to_token_indexs[req_id, :L] = pool[p : p + L].to(torch.int32) p += L - # Randomly place contiguous image-token spans inside each batch's NEW region. 
- b_image_token_tag = torch.zeros(sum_new, dtype=torch.bool) + b_image_token_end = torch.zeros(sum_new, dtype=torch.int32) + position_ids_0 = torch.empty(sum_new, dtype=torch.int32) for i in range(batch): M = int(new_lens[i].item()) + P = int(prompt_lens[i].item()) + start_pack = int(b_start_loc[i].item()) + position_ids_0[start_pack : start_pack + M] = torch.arange(P, P + M, dtype=torch.int32) if M < 2: continue if torch.rand((), generator=g).item() > image_prob: continue n_spans = int(torch.randint(1, num_image_spans_max + 1, (1,), generator=g).item()) - start_pack = int(b_start_loc[i].item()) + cursor = 0 for _ in range(n_spans): - span_len = int(torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item()) - span_len = min(span_len, M) - s_rel = int(torch.randint(0, M - span_len + 1, (1,), generator=g).item()) - b_image_token_tag[start_pack + s_rel : start_pack + s_rel + span_len] = True + remaining = M - cursor + if remaining <= 0: + break + gap = int(torch.randint(0, remaining, (1,), generator=g).item()) + s_rel = cursor + gap + max_span_len = min(image_span_len_max, M - s_rel) + if max_span_len <= 0: + break + span_len = int(torch.randint(1, max_span_len + 1, (1,), generator=g).item()) + e_rel = s_rel + span_len + image_gid = P + s_rel + image_end = P + e_rel + position_ids_0[start_pack + s_rel : start_pack + e_rel] = image_gid + b_image_token_end[start_pack + s_rel : start_pack + e_rel] = image_end + cursor = e_rel b_seq_len = total_lens.to(torch.int32) b_prompt_cache_len = prompt_lens.to(torch.int32) - # position_ids[0]: kernel API still requires it even though its current - # mask logic only reads b_image_token_tag. - position_ids_0 = torch.empty(sum_new, dtype=torch.int32) - for i in range(batch): - M = int(new_lens[i].item()) - P = int(prompt_lens[i].item()) - s = int(b_start_loc[i].item()) - position_ids_0[s : s + M] = torch.arange(P, P + M, dtype=torch.int32) - q = torch.randn((sum_new, Hq, D), dtype=dtype, device=device) k = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device) v = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device) @@ -182,13 +198,13 @@ def _build_inputs( b_prompt_cache_len=b_prompt_cache_len.to(device), max_new_len=max_new_len, req_to_token_indexs=req_to_token_indexs.to(device), - b_image_token_tag=b_image_token_tag.to(device), + b_image_token_end=b_image_token_end.to(device), new_lens=new_lens, prompt_lens=prompt_lens, ) -def _report_per_batch_error(out_triton, out_ref, new_lens, b_start_loc, image_tag, tag=""): +def _report_per_batch_error(out_triton, out_ref, new_lens, b_start_loc, image_token_end, tag=""): print(f"\n[{tag}] per-batch error breakdown (abs / rel / cos):") for i in range(new_lens.shape[0]): s = int(b_start_loc[i].item()) @@ -201,7 +217,7 @@ def _report_per_batch_error(out_triton, out_ref, new_lens, b_start_loc, image_ta denom = b.abs().max().item() + 1e-6 rel_err = abs_err / denom cos = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() - n_img = int(image_tag[s : s + m].sum().item()) + n_img = int((image_token_end[s : s + m] > 0).sum().item()) print( f" batch {i:02d} | M={m:4d} | image_tokens={n_img:4d} | " f"max_abs={abs_err:.4e} | max_rel={rel_err:.4e} | cos={cos:.6f}" @@ -249,7 +265,7 @@ def _run_case( inputs["b_prompt_cache_len"], inputs["max_new_len"], inputs["req_to_token_indexs"], - inputs["b_image_token_tag"], + inputs["b_image_token_end"], ) out_triton = inputs["o"] @@ -257,12 +273,13 @@ def _run_case( inputs["q"], inputs["k"], inputs["v"], + inputs["position_ids_0"], 
inputs["b_req_idx"], inputs["b_start_loc"], inputs["b_seq_len"], inputs["b_prompt_cache_len"], inputs["req_to_token_indexs"], - inputs["b_image_token_tag"], + inputs["b_image_token_end"], ) a = out_triton.float() @@ -272,8 +289,8 @@ def _run_case( rel_err = abs_err / denom cos = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item() - n_image = int(inputs["b_image_token_tag"].sum().item()) - n_tokens = int(inputs["b_image_token_tag"].numel()) + n_image = int((inputs["b_image_token_end"] > 0).sum().item()) + n_tokens = int(inputs["b_image_token_end"].numel()) if verbose: print( f"\ncase: batch={batch} Hq={Hq} Hk={Hk} D={D} dtype={dtype} " @@ -289,7 +306,7 @@ def _run_case( out_ref, inputs["new_lens"], inputs["b_start_loc"], - inputs["b_image_token_tag"], + inputs["b_image_token_end"], tag=f"seed={seed}", ) From c09bd8fd76ee73f7d66adabfca4ca37107e9aa0e Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 8 May 2026 09:13:33 +0000 Subject: [PATCH 41/41] delete position ids --- .../layer_infer/transformer_layer_infer.py | 1 - .../layer_infer/transformer_layer_infer.py | 1 - .../context_attention_fwd_neo.py | 51 +------------------ .../test_context_attention_fwd_neo.py | 22 +------- 4 files changed, 4 insertions(+), 71 deletions(-) diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py index e9623b99a6..cf88868a8a 100644 --- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py @@ -106,7 +106,6 @@ def _context_attention_kernel( kv[:, 0 : self.tp_k_head_num_, :], kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] infer_state.b_req_idx, infer_state.b_q_start_loc, infer_state.b_seq_len, diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py index 271c9592c3..9a791ff221 100644 --- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py @@ -104,7 +104,6 @@ def _context_attention_kernel( kv[:, 0 : self.tp_k_head_num_, :], kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), - infer_state.position_ids[0], # [0,0,1,2,3,3,3,4] infer_state.b_req_idx, infer_state.b_q_start_loc, infer_state.b_seq_len, diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py index 8cefe486c0..1016e5a3c1 100644 --- a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py +++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py @@ -13,7 +13,6 @@ def _fwd_kernel( V, sm_scale, Out, - position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0] B_Start_Loc, B_Seqlen, Req_to_tokens, @@ -79,7 +78,6 @@ def _fwd_kernel( acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32) # absolute q positions in the request q_pos = prompt_cache_len + offs_m # [M] - q_gid = tl.load(position_ids + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=-1) q_image_end = tl.load(b_image_token_end + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=0) causal_end = tl.minimum(prompt_cache_len + 
block_start_loc + BLOCK_M, total_len) @@ -91,13 +89,6 @@ def _fwd_kernel( k_pos = start_n + offs_n # [N] k_valid = k_pos < block_end_loc - k_in_new = k_pos >= prompt_cache_len - k_rel = k_pos - prompt_cache_len - k_gid = tl.load( - position_ids + cur_batch_in_all_start_index + k_rel, - mask=k_valid & k_in_new, - other=-2, - ) # map logical pos -> mem_index (for K/V) kv_loc = tl.load( @@ -113,12 +104,7 @@ def _fwd_kernel( qk += tl.dot(q, k) causal_mask = q_pos[:, None] >= k_pos[None, :] - image_mask = ( - (q_image_end[:, None] > 0) - & k_in_new[None, :] - & (k_pos[None, :] < q_image_end[:, None]) - & (q_gid[:, None] == k_gid[None, :]) - ) + image_mask = k_pos[None, :] < q_image_end[:, None] mask = (causal_mask | image_mask) & k_valid[None, :] qk = tl.where(mask, qk * sm_scale, -1.0e8) @@ -157,7 +143,6 @@ def context_attention_fwd_neo( k, v, o, - position_ids, # 1D packed like q (only NEW tokens) b_req_idx, b_start_loc, b_seq_len, @@ -166,9 +151,6 @@ def context_attention_fwd_neo( req_to_token_indexs, b_image_token_end, ): - # minimal safety: position_ids must cover packed q rows - assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0]) - BLOCK_M = 128 if not is_tesla() else 64 Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] @@ -191,7 +173,6 @@ def context_attention_fwd_neo( v, sm_scale, o, - position_ids, b_start_loc, b_seq_len, req_to_token_indexs, @@ -227,7 +208,6 @@ def reference_attention( q, k, v, - position_ids_q, # 1D packed like q (only NEW tokens) b_image_token_end, b_req_idx, b_start_loc, @@ -253,7 +233,6 @@ def reference_attention( q_start = int(b_start_loc[b].item()) q_blk = q[q_start : q_start + new_len] # [M, Hq, D] - gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M] image_end_new = b_image_token_end[q_start : q_start + new_len].to(torch.int64) # [M] # gather K/V for full request by logical pos -> mem_index @@ -272,22 +251,7 @@ def reference_attention( # build allow mask: # causal always allow = k_pos[None, :] <= q_pos[:, None] - - # full-attn only inside NEW part by gid - # compare only when k_pos in NEW - k_in_new = k_pos >= prompt_len - k_rel = (k_pos - prompt_len).clamp_min(0) # [L] - # map k_rel to gid_new, but only valid where k_in_new - k_gid = torch.empty((total_len,), device=device, dtype=torch.int64) - k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new - k_gid[k_in_new] = gid_new[k_rel[k_in_new]] - - allow = allow | ( - (image_end_new[:, None] > 0) - & k_in_new[None, :] - & (k_pos[None, :] < image_end_new[:, None]) - & (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :]) - ) + allow = allow | (k_pos[None, :] < image_end_new[:, None]) # scores: [Hq, M, L] q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D] @@ -355,25 +319,18 @@ def make_test_case( req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) p += L - # position_ids_q: only NEW tokens, packed like q - position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32) b_image_token_end = torch.zeros((sum_q,), device=device, dtype=torch.int32) for b in range(batch): M = int(new_lens[b].item()) P = int(prompt_lens[b].item()) start = int(b_start_loc[b].item()) - gid = torch.arange(P, P + M, device=device, dtype=torch.int32) - # make one repeated block inside NEW part to simulate image tokens if M >= 4 and torch.rand((), device=device).item() > 0.3: s = int(torch.randint(0, M - 2, (1,), device=device).item()) e = min(M, s + 3) - gid[s:e] = gid[s] b_image_token_end[start + s : start + e] = P + e - position_ids_q[start : start + M] = gid - 
q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) @@ -384,7 +341,6 @@ def make_test_case( k, v, o, - position_ids_q, b_image_token_end, b_req_idx, b_start_loc, @@ -401,7 +357,6 @@ def check_once(device="cuda", dtype=torch.float16, seed=0): k, v, o, - position_ids_q, b_image_token_end, b_req_idx, b_start_loc, @@ -416,7 +371,6 @@ def check_once(device="cuda", dtype=torch.float16, seed=0): k, v, o, - position_ids_q, b_req_idx, b_start_loc, b_seq_len, @@ -430,7 +384,6 @@ def check_once(device="cuda", dtype=torch.float16, seed=0): q, k, v, - position_ids_q, b_image_token_end, b_req_idx, b_start_loc, diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py index 730d4263ee..5cb8bd58ff 100644 --- a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py +++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py @@ -5,7 +5,7 @@ batch element we gather K/V for the whole request (prompt + new tokens) via ``req_to_token_indexs`` and apply:: - allow[m, k] = (k <= q_pos[m]) OR same_image_group[m, k] + allow[m, k] = (k <= q_pos[m]) OR (k < image_end[m]) i.e. normal queries are causal, and image-token queries can only see future tokens from the same image span. If the Triton kernel disagrees with this @@ -35,7 +35,6 @@ def torch_reference_context_attention_neo( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - position_ids_0: torch.Tensor, b_req_idx: torch.Tensor, b_start_loc: torch.Tensor, b_seq_len: torch.Tensor, @@ -62,7 +61,6 @@ def torch_reference_context_attention_neo( q_start = int(b_start_loc[b].item()) q_blk = q[q_start : q_start + new] # [M, Hq, D] - q_gid = position_ids_0[q_start : q_start + new].to(torch.int64) q_image_end = b_image_token_end[q_start : q_start + new].to(torch.int64) token_locs = req_to_token_indexs[req_idx, :total].to(torch.int64) @@ -72,16 +70,7 @@ def torch_reference_context_attention_neo( q_pos = torch.arange(prompt, total, device=device, dtype=torch.int64) # [M] k_pos = torch.arange(0, total, device=device, dtype=torch.int64) # [total] causal = k_pos[None, :] <= q_pos[:, None] - - k_gid = torch.full((total,), -1, device=device, dtype=torch.int64) - k_gid[prompt:total] = q_gid - same_image = ( - (q_image_end[:, None] > 0) - & (k_pos[None, :] >= prompt) - & (k_pos[None, :] < q_image_end[:, None]) - & (q_gid[:, None] == k_gid[None, :]) - ) - allow = causal | same_image + allow = causal | (k_pos[None, :] < q_image_end[:, None]) out_blk = torch.empty_like(q_blk) for h in range(Hq): @@ -149,12 +138,10 @@ def _build_inputs( p += L b_image_token_end = torch.zeros(sum_new, dtype=torch.int32) - position_ids_0 = torch.empty(sum_new, dtype=torch.int32) for i in range(batch): M = int(new_lens[i].item()) P = int(prompt_lens[i].item()) start_pack = int(b_start_loc[i].item()) - position_ids_0[start_pack : start_pack + M] = torch.arange(P, P + M, dtype=torch.int32) if M < 2: continue if torch.rand((), generator=g).item() > image_prob: @@ -172,9 +159,7 @@ def _build_inputs( break span_len = int(torch.randint(1, max_span_len + 1, (1,), generator=g).item()) e_rel = s_rel + span_len - image_gid = P + s_rel image_end = P + e_rel - position_ids_0[start_pack + s_rel : start_pack + e_rel] = image_gid b_image_token_end[start_pack + s_rel : start_pack 
+ e_rel] = image_end cursor = e_rel @@ -191,7 +176,6 @@ def _build_inputs( k=k, v=v, o=o, - position_ids_0=position_ids_0.to(device), b_req_idx=b_req_idx.to(device), b_start_loc=b_start_loc.to(device), b_seq_len=b_seq_len.to(device), @@ -258,7 +242,6 @@ def _run_case( inputs["k"], inputs["v"], inputs["o"], - inputs["position_ids_0"], inputs["b_req_idx"], inputs["b_start_loc"], inputs["b_seq_len"], @@ -273,7 +256,6 @@ def _run_case( inputs["q"], inputs["k"], inputs["v"], - inputs["position_ids_0"], inputs["b_req_idx"], inputs["b_start_loc"], inputs["b_seq_len"],