From 7c1f2d1b19225118cbf51353995e56bff1b0c4fa Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Fri, 27 Mar 2026 07:44:37 +0000
Subject: [PATCH 01/41] add neo_chat and neo_chat_moe
---
lightllm/models/__init__.py | 2 +
lightllm/models/llama/model.py | 49 +-
lightllm/models/neo_chat/__init__.py | 0
.../models/neo_chat/layer_infer/__init__.py | 0
.../layer_infer/transformer_layer_infer.py | 117 +++++
.../models/neo_chat/layer_weights/__init__.py | 0
.../pre_and_post_layer_weight.py | 23 +
.../layer_weights/transformer_layer_weight.py | 57 +++
lightllm/models/neo_chat/model.py | 53 +++
lightllm/models/neo_chat_moe/__init__.py | 0
lightllm/models/neo_chat_moe/infer_struct.py | 103 +++++
.../neo_chat_moe/layer_infer/__init__.py | 0
.../layer_infer/transformer_layer_infer.py | 117 +++++
.../neo_chat_moe/layer_weights/__init__.py | 0
.../pre_and_post_layer_weight.py | 23 +
.../layer_weights/transformer_layer_weight.py | 48 ++
lightllm/models/neo_chat_moe/model.py | 192 ++++++++
lightllm/models/neo_chat_moe/neo_visual.py | 281 ++++++++++++
.../neo_chat_moe/triton_kernel/__init__.py | 0
.../context_attention_fwd_neo.py | 430 ++++++++++++++++++
.../triton_kernel/get_neo_position.py | 191 ++++++++
.../models/neo_chat_moe/vision_process.py | 141 ++++++
.../visualserver/model_infer/model_rpc.py | 3 +
23 files changed, 1826 insertions(+), 4 deletions(-)
create mode 100644 lightllm/models/neo_chat/__init__.py
create mode 100644 lightllm/models/neo_chat/layer_infer/__init__.py
create mode 100644 lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
create mode 100644 lightllm/models/neo_chat/layer_weights/__init__.py
create mode 100644 lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py
create mode 100644 lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py
create mode 100644 lightllm/models/neo_chat/model.py
create mode 100644 lightllm/models/neo_chat_moe/__init__.py
create mode 100644 lightllm/models/neo_chat_moe/infer_struct.py
create mode 100644 lightllm/models/neo_chat_moe/layer_infer/__init__.py
create mode 100644 lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
create mode 100644 lightllm/models/neo_chat_moe/layer_weights/__init__.py
create mode 100644 lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py
create mode 100644 lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py
create mode 100644 lightllm/models/neo_chat_moe/model.py
create mode 100644 lightllm/models/neo_chat_moe/neo_visual.py
create mode 100644 lightllm/models/neo_chat_moe/triton_kernel/__init__.py
create mode 100644 lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
create mode 100644 lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
create mode 100644 lightllm/models/neo_chat_moe/vision_process.py
diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py
index f2e29d4a88..a24cbf140a 100644
--- a/lightllm/models/__init__.py
+++ b/lightllm/models/__init__.py
@@ -39,4 +39,6 @@
)
from lightllm.models.gpt_oss.model import GptOssTpPartModel
from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel
+from lightllm.models.neo_chat_moe.model import NeoTpMOEPartModel
+from lightllm.models.neo_chat.model import NeoTpPartModel
from .registry import get_model, get_model_class
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index c104ebccc9..1c0277e59b 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -73,10 +73,8 @@ def _init_custom(self):
"""
rope_scaling = self.config.get("rope_scaling", None)
if rope_scaling is None:
- self._init_to_get_rotary()
- return
-
- if "rope_type" in rope_scaling:
+ scaling_type = "default"
+ elif "rope_type" in rope_scaling:
scaling_type = rope_scaling["rope_type"]
elif "type" in rope_scaling:
scaling_type = rope_scaling["type"]
@@ -96,6 +94,8 @@ def _init_custom(self):
self._init_to_get_rotary()
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+ if "rope_theta_hw" in self.config:
+ self._init_to_get_hw_rotary()
return
def _init_to_get_rotary(self, default_base=10000):
@@ -301,3 +301,44 @@ def _init_to_get_llama3_rotary(self, default_base=10000):
self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return
+
+ def _init_to_get_hw_rotary(self, default_base=10000):
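+        # build separate cos/sin caches for the height/width rotary axes (driven by
+        # rope_theta_hw), analogous to the temporal caches built by _init_to_get_rotary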
+ partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2)
+ if self.config.get("rope_scaling", {}) is None:
+ rope_scaling_factor = 1.0
+ else:
+ rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+
+ base = self.config.get("rope_theta_hw", float(default_base))
+ if "max_sequence_length" in self.config:
+ max_seq_len = self.config["max_sequence_length"]
+ else:
+ max_position_embeddings = self.config.get(
+ "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384
+ )
+ max_seq_len = max_position_embeddings * rope_scaling_factor
+
+ # NTK
+ try:
+ ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
+ assert ntk_alpha >= 1
+ if ntk_alpha > 1:
+ logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+ max_seq_len *= ntk_alpha
+ base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
+ except:
+ pass
+
+ full_inv_freq = 1.0 / (
+ base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+ )
+ inv_freq = full_inv_freq[::2]
+ t = (
+ torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
+ / rope_scaling_factor
+ )
+ freqs = torch.outer(t, inv_freq)
+
+ self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+ self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda()
+ return
diff --git a/lightllm/models/neo_chat/__init__.py b/lightllm/models/neo_chat/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat/layer_infer/__init__.py b/lightllm/models/neo_chat/layer_infer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
new file mode 100644
index 0000000000..ec181a0b8d
--- /dev/null
+++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,117 @@
+import torch
+from functools import partial
+from typing import Tuple
+from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo
+from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd
+from lightllm.models.qwen3.layer_infer.transformer_layer_infer import Qwen3TransformerLayerInfer
+from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight
+from lightllm.distributed import all_reduce
+import torch.distributed as dist
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.common.basemodel.attention.base_att import AttControl
+
+
+class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer):
+ def __init__(self, data_type, network_config):
+ super().__init__(data_type, network_config)
+ return
+
+    def _bind_attention(self):
+        # keep the kernels defined on this class (overrides the parent's attention binding)
+        self._context_attention_kernel = self._context_attention_kernel
+        self._token_attention_kernel = self._token_attention_kernel
+        return
+
+ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatTransformerLayerWeight):
+ input = input.view(-1, self.embed_dim_)
+
+ qkv = layer_weight.qkv_proj.mm(input)
+ q, cache_kv = qkv.split(
+ [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1
+ )
+ q = q.view(q.shape[0], self.tp_q_head_num_, self.head_dim_)
+ q_t, q_hw = q.chunk(2, dim=-1)
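+        # head_dim layout: the first half of each head carries the temporal rotary dims,
+        # the second half is split again into height and width rotary dims below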
+
+ cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+ k = cache_kv[:, : self.tp_k_head_num_, :]
+ v = cache_kv[:, self.tp_k_head_num_ :, :]
+ k_t, k_hw = k.chunk(2, dim=-1)
+
+ q_t_2d = q_t.reshape(q.shape[0], -1)
+ q_hw_2d = q_hw.reshape(q.shape[0], -1)
+ k_t_2d = k_t.reshape(k.shape[0], -1)
+ k_hw_2d = k_hw.reshape(k.shape[0], -1)
+ layer_weight.qk_norm_weight_(q_t_2d, k_t_2d, eps=self.eps_)
+ layer_weight.qk_hw_norm_weight_(q_hw_2d, k_hw_2d, eps=self.eps_)
+
+ q_t = q_t_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
+ q_hw = q_hw_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
+ q_h, q_w = q_hw.chunk(2, dim=-1)
+
+ k_t = k_t_2d.view(k.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
+ k_hw = k_hw_2d.view(k.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
+ k_h, k_w = k_hw.chunk(2, dim=-1)
+
+ rotary_emb_fwd(
+ q_t,
+ k_t,
+ infer_state.position_cos,
+ infer_state.position_sin,
+ )
+ rotary_emb_fwd(
+ q_h,
+ k_h,
+ infer_state.position_cos_h,
+ infer_state.position_sin_h,
+ )
+ rotary_emb_fwd(
+ q_w,
+ k_w,
+ infer_state.position_cos_w,
+ infer_state.position_sin_w,
+ )
+
+ q = torch.cat([q_t, q_h, q_w], dim=-1)
+ q = q.reshape(q.shape[0], -1)
+
+ k = torch.cat([k_t, k_h, k_w], dim=-1)
+ cache_kv = torch.cat([k, v], dim=1)
+ return q, cache_kv
+
+ def _context_attention_kernel(
+ self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
+ ) -> torch.Tensor:
+ o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+ kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+ context_attention_fwd_neo(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ kv[:, 0 : self.tp_k_head_num_, :],
+ kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+ o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
+ infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
+ infer_state.b_req_idx,
+ infer_state.b_q_start_loc,
+ infer_state.b_seq_len,
+ infer_state.b_ready_cache_len,
+ infer_state.max_q_seq_len,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.b_image_token_tag,
+ )
+ return o_tensor
+
+ def _token_attention_kernel(
+ self,
+ q: torch.Tensor,
+ infer_state: NeoChatInferStateInfo,
+ layer_weight: NeoChatTransformerLayerWeight,
+ ) -> torch.Tensor:
+ _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_)
+ _q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
+ att_control = AttControl()
+ # att_control.mla_decode_dict["softmax_scale"] = 1.0 / (self.head_dim_ ** 0.5)
+ o_tensor = infer_state.decode_att_state.decode_att(
+ q=_q, k=_k, v=_v, att_control=att_control, alloc_func=self.alloc_tensor
+ )
+ o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_)[:, :, : self.head_dim_].contiguous()
+ return o_tensor
diff --git a/lightllm/models/neo_chat/layer_weights/__init__.py b/lightllm/models/neo_chat/layer_weights/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 0000000000..e6489f39af
--- /dev/null
+++ b/lightllm/models/neo_chat/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,23 @@
+import torch
+import numpy as np
+from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
+
+# rename keys: language_model.xxx -> xxx
+# only rename keys when PreAndPostLayerWeight loads; TransformerLayerWeight keys already match
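+# e.g. a key such as "language_model.model.embed_tokens.weight" (illustrative name) becomes "model.embed_tokens.weight"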
+def rename_weight_keys(weights):
+ prefix = "language_model."
+ keys = list(weights.keys())
+ for k in keys:
+ if prefix in k:
+ weights[k.replace(prefix, "")] = weights.pop(k)
+
+
+class NeoChatPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config):
+ super().__init__(data_type, network_config)
+ return
+
+ def load_hf_weights(self, weights):
+ rename_weight_keys(weights)
+ super().load_hf_weights(weights)
+ return
diff --git a/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py
new file mode 100644
index 0000000000..8351369fd8
--- /dev/null
+++ b/lightllm/models/neo_chat/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,57 @@
+from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ QKRMSNORMWeight,
+ RMSNormWeight,
+ QKVROWNMMWeight,
+)
+
+
+class NeoChatTransformerLayerWeight(Qwen3TransformerLayerWeight):
+ def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
+ super().__init__(layer_num, data_type, network_config, quant_cfg)
+ return
+
+ def _init_weight_names(self):
+ super()._init_weight_names()
+ self._q_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_hw.weight"
+ self._k_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_hw.weight"
+
+ def _init_qkv(self):
+ in_dim = self.n_embed
+ self.qkv_proj = QKVROWNMMWeight(
+ in_dim=in_dim,
+ q_head_num=self.q_head_num_,
+ kv_head_num=self.k_head_num_,
+ head_dim=self.head_dim,
+ weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name],
+ data_type=self.data_type_,
+ bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name],
+ quant_method=self.get_quant_method("qkv_proj"),
+ )
+
+ def _init_norm(self):
+ hidden_size = self.network_config_["hidden_size"]
+ self.att_norm_weight_ = RMSNormWeight(
+ dim=hidden_size,
+ weight_name=self._att_norm_weight_name,
+ data_type=self.data_type_,
+ )
+ self.ffn_norm_weight_ = RMSNormWeight(
+ dim=hidden_size,
+ weight_name=self._ffn_norm_weight_name,
+ data_type=self.data_type_,
+ )
+
+ self.qk_norm_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim // 2,
+ q_weight_name=self._q_norm_name,
+ k_weight_name=self._k_norm_name,
+ data_type=self.data_type_,
+ )
+
+ self.qk_hw_norm_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim // 2,
+ q_weight_name=self._q_norm_hw_name,
+ k_weight_name=self._k_norm_hw_name,
+ data_type=self.data_type_,
+ )
diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py
new file mode 100644
index 0000000000..14d9f96dc7
--- /dev/null
+++ b/lightllm/models/neo_chat/model.py
@@ -0,0 +1,53 @@
+import os
+import json
+from lightllm.common.build_utils import repair_config
+from lightllm.models.registry import ModelRegistry, llm_model_type_is
+from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
+from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer
+from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer
+from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight
+from lightllm.models.qwen2_vl.model import QWen2VLTokenizer
+from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.server.core.objs import SamplingParams
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
+from lightllm.models.neo_chat_moe.vision_process import smart_resize
+from lightllm.models.internvl.model import InternvlTokenizer
+from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.models.neo_chat.layer_infer.transformer_layer_infer import NeoChatTransformerLayerInfer
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.neo_chat.layer_weights.transformer_layer_weight import NeoChatTransformerLayerWeight
+from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatPreAndPostLayerWeight
+from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+
+
+@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3"))
+class NeoTpPartModel(Qwen3TpPartModel):
+
+ pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+ transformer_layer_infer_class = NeoChatTransformerLayerInfer
+
+ pre_and_post_weight_class = NeoChatPreAndPostLayerWeight
+ transformer_weight_class = NeoChatTransformerLayerWeight
+
+ infer_state_class = NeoChatInferStateInfo
+
+ def __init__(self, kvargs):
+ super().__init__(kvargs)
+ return
+
+ def _init_inferstate_cls(self):
+ pass
+
+ def _init_config(self):
+ with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+ all_config = json.load(json_file)
+ self.config = all_config["llm_config"]
+ # rename keys
+ repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+ repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+ repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+ if self.finetune_config:
+ self.config["vocab_size"] = self.finetune_config.vocab_size
+ return
diff --git a/lightllm/models/neo_chat_moe/__init__.py b/lightllm/models/neo_chat_moe/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py
new file mode 100644
index 0000000000..add8abda08
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/infer_struct.py
@@ -0,0 +1,103 @@
+from typing import Optional, List
+import torch
+import numpy as np
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.common.req_manager import ReqManager
+from lightllm.models.neo_chat_moe.triton_kernel.get_neo_position import get_neo_position_triton
+from lightllm.models.llama.model import LlamaTpPartModel
+
+
+class NeoChatInferStateInfo(LlamaInferStateInfo):
+ def __init__(self):
+ super().__init__()
+ self.position_cos = None
+ self.position_sin = None
+ self.position_cos_h = None
+ self.position_sin_h = None
+ self.position_cos_w = None
+ self.position_sin_w = None
+
+ def init_some_extra_state(self, model: LlamaTpPartModel):
+ LlamaInferStateInfo.init_some_extra_state(self, model)
+ if self.is_prefill:
+ self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda(
+ non_blocking=True
+ )
+ self.position_ids = self.get_neo_position(self.multimodal_params)
+ else:
+ b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])]
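+            # grid_thwd[3] holds (1 - token_num) per image (see get_image_token_length), so the
+            # summed delta shifts the decode-time temporal position back by token_num - 1 per image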
+ for batch_idx, p in enumerate(self.multimodal_params):
+ position_delta = 0
+ for image in p["images"]:
+ position_delta += image["grid_thwd"][3]
+ b_position_delta[batch_idx] = position_delta
+ position_ids = self.position_ids + torch.tensor(b_position_delta, device=self.position_ids.device)
+ self.position_ids = position_ids.unsqueeze(0).expand(3, -1).clone()
+ self.position_ids[1:].zero_()
+
+ self.position_ids = self.position_ids.contiguous()
+ self.position_cos = model._cos_cached[self.position_ids[0]]
+ self.position_sin = model._sin_cached[self.position_ids[0]]
+ self.position_cos_h = model._hw_cos_cached[self.position_ids[1]]
+ self.position_sin_h = model._hw_sin_cached[self.position_ids[1]]
+ self.position_cos_w = model._hw_cos_cached[self.position_ids[2]]
+ self.position_sin_w = model._hw_sin_cached[self.position_ids[2]]
+ return
+
+ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor:
+ if len(multimodal_params) == 0:
+ position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+ position_ids[0].copy_(self.position_ids)
+ return position_ids
+ b_image_start_idx = []
+ b_image_nums = []
+ b_image_start_num = []
+ b_image_len = []
+ image_start_num = 0
+ b_image_thwd = []
+
+ # pad multimodal_params to batch size.
+ batch_size = self.b_q_seq_len.shape[0]
+ multimodal_params = multimodal_params + [
+ {"images": [], "audios": []} for _ in range(batch_size - len(multimodal_params))
+ ]
+
+ for _, p in enumerate(multimodal_params):
+ images = p.get("images", [])
+ for img in images:
+ b_image_start_idx.append(img["start_idx"])
+ # a = img["start_idx"]
+ # print(f"img start_idx: {a}")
+ b_image_len.append(img["token_num"])
+ b_image_thwd.append(img["grid_thwd"])
+ b_image_nums.append(len(images))
+ b_image_start_num.append(image_start_num)
+ image_start_num += len(images)
+
+        # no images at all in this batch
+ if image_start_num == 0:
+ position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+ position_ids[0].copy_(self.position_ids)
+ return position_ids.contiguous()
+ b_image_start_idx = torch.tensor(b_image_start_idx, device="cpu").cuda(non_blocking=True)
+ b_image_thwd = torch.tensor(b_image_thwd, device="cpu").cuda(non_blocking=True) # image_num x 4
+ b_image_nums = torch.tensor(b_image_nums, device="cpu").cuda(non_blocking=True)
+ b_image_start_num = torch.tensor(b_image_start_num, device="cpu").cuda(non_blocking=True)
+ b_image_len = torch.tensor(b_image_len, device="cpu").cuda(non_blocking=True)
+
+ position_ids = self.position_ids.new_zeros((3, self.position_ids.size(0)))
+ position_ids[0].copy_(self.position_ids)
+
+ get_neo_position_triton(
+ b_image_start_idx=b_image_start_idx,
+ b_image_thwd=b_image_thwd,
+ b_image_nums=b_image_nums,
+ b_image_start_num=b_image_start_num,
+ b_image_len=b_image_len,
+ position_ids=position_ids,
+ b_ready_cache_len=self.b_ready_cache_len,
+ b_q_seq_len=self.b_q_seq_len,
+ b_start_loc=self.b_q_start_loc,
+ b_image_token_tag=self.b_image_token_tag,
+ )
+ return position_ids
diff --git a/lightllm/models/neo_chat_moe/layer_infer/__init__.py b/lightllm/models/neo_chat_moe/layer_infer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
new file mode 100644
index 0000000000..4c4d8a22ab
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,117 @@
+import torch
+from functools import partial
+from typing import Tuple
+from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import context_attention_fwd_neo
+from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd
+from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer
+from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight
+from lightllm.distributed import all_reduce
+import torch.distributed as dist
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.common.basemodel.attention.base_att import AttControl
+
+
+class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
+ def __init__(self, data_type, network_config):
+ super().__init__(data_type, network_config)
+ return
+
+ def _bind_attention(self):
+ self._context_attention_kernel = self._context_attention_kernel
+ self._token_attention_kernel = self._token_attention_kernel
+ return
+
+ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoChatMOETransformerLayerWeight):
+ input = input.view(-1, self.embed_dim_)
+
+ qkv = layer_weight.qkv_proj.mm(input)
+ q, cache_kv = qkv.split(
+ [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1
+ )
+ q = q.view(q.shape[0], self.tp_q_head_num_, self.head_dim_)
+ q_t, q_hw = q.chunk(2, dim=-1)
+
+ cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+ k = cache_kv[:, : self.tp_k_head_num_, :]
+ v = cache_kv[:, self.tp_k_head_num_ :, :]
+ k_t, k_hw = k.chunk(2, dim=-1)
+
+ q_t_2d = q_t.reshape(q.shape[0], -1)
+ q_hw_2d = q_hw.reshape(q.shape[0], -1)
+ k_t_2d = k_t.reshape(k.shape[0], -1)
+ k_hw_2d = k_hw.reshape(k.shape[0], -1)
+ layer_weight.qk_norm_weight_(q_t_2d, k_t_2d, eps=self.eps_)
+ layer_weight.qk_hw_norm_weight_(q_hw_2d, k_hw_2d, eps=self.eps_)
+
+ q_t = q_t_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
+ q_hw = q_hw_2d.view(q.shape[0], self.tp_q_head_num_, self.head_dim_ // 2)
+ q_h, q_w = q_hw.chunk(2, dim=-1)
+
+ k_t = k_t_2d.view(k.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
+ k_hw = k_hw_2d.view(k.shape[0], self.tp_k_head_num_, self.head_dim_ // 2)
+ k_h, k_w = k_hw.chunk(2, dim=-1)
+
+ rotary_emb_fwd(
+ q_t,
+ k_t,
+ infer_state.position_cos,
+ infer_state.position_sin,
+ )
+ rotary_emb_fwd(
+ q_h,
+ k_h,
+ infer_state.position_cos_h,
+ infer_state.position_sin_h,
+ )
+ rotary_emb_fwd(
+ q_w,
+ k_w,
+ infer_state.position_cos_w,
+ infer_state.position_sin_w,
+ )
+
+ q = torch.cat([q_t, q_h, q_w], dim=-1)
+ q = q.reshape(q.shape[0], -1)
+
+ k = torch.cat([k_t, k_h, k_w], dim=-1)
+ cache_kv = torch.cat([k, v], dim=1)
+ return q, cache_kv
+
+ def _context_attention_kernel(
+ self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
+ ) -> torch.Tensor:
+ o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+ kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+ context_attention_fwd_neo(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ kv[:, 0 : self.tp_k_head_num_, :],
+ kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+ o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
+ infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
+ infer_state.b_req_idx,
+ infer_state.b_q_start_loc,
+ infer_state.b_seq_len,
+ infer_state.b_ready_cache_len,
+ infer_state.max_q_seq_len,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.b_image_token_tag,
+ )
+ return o_tensor
+
+ def _token_attention_kernel(
+ self,
+ q: torch.Tensor,
+ infer_state: NeoChatInferStateInfo,
+ layer_weight: NeoChatMOETransformerLayerWeight,
+ ) -> torch.Tensor:
+ _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_)
+ _q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
+ att_control = AttControl()
+ # att_control.mla_decode_dict["softmax_scale"] = 1.0 / (self.head_dim_ ** 0.5)
+ o_tensor = infer_state.decode_att_state.decode_att(
+ q=_q, k=_k, v=_v, att_control=att_control, alloc_func=self.alloc_tensor
+ )
+ o_tensor = o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_)[:, :, : self.head_dim_].contiguous()
+ return o_tensor
diff --git a/lightllm/models/neo_chat_moe/layer_weights/__init__.py b/lightllm/models/neo_chat_moe/layer_weights/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 0000000000..4b0eae91c3
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,23 @@
+import torch
+import numpy as np
+from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
+
+# rename keys: language_model.xxx -> xxx
+# only rename keys when PreAndPostLayerWeight loads; TransformerLayerWeight keys already match
+def rename_weight_keys(weights):
+ prefix = "language_model."
+ keys = list(weights.keys())
+ for k in keys:
+ if prefix in k:
+ weights[k.replace(prefix, "")] = weights.pop(k)
+
+
+class NeoChatMOEPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config):
+ super().__init__(data_type, network_config)
+ return
+
+ def load_hf_weights(self, weights):
+ rename_weight_keys(weights)
+ super().load_hf_weights(weights)
+ return
diff --git a/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py
new file mode 100644
index 0000000000..d4f985db45
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,48 @@
+from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ QKRMSNORMWeight,
+ ROWMMWeight,
+ RMSNormWeight,
+)
+
+
+class NeoChatMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight):
+ def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
+ self._is_merge_kv = network_config.get("merge_kv", True)
+ super().__init__(layer_num, data_type, network_config, quant_cfg)
+ return
+
+ def _init_weight_names(self):
+ super()._init_weight_names()
+ self._q_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.q_norm_hw.weight"
+ self._k_norm_hw_name = f"model.layers.{self.layer_num_}.self_attn.k_norm_hw.weight"
+
+ def _init_qkv(self):
+ super()._init_qkv()
+
+ def _init_norm(self):
+ hidden_size = self.network_config_["hidden_size"]
+ self.att_norm_weight_ = RMSNormWeight(
+ dim=hidden_size,
+ weight_name=self._att_norm_weight_name,
+ data_type=self.data_type_,
+ )
+ self.ffn_norm_weight_ = RMSNormWeight(
+ dim=hidden_size,
+ weight_name=self._ffn_norm_weight_name,
+ data_type=self.data_type_,
+ )
+
+ self.qk_norm_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim // 2,
+ q_weight_name=self._q_norm_name,
+ k_weight_name=self._k_norm_name,
+ data_type=self.data_type_,
+ )
+
+ self.qk_hw_norm_weight_ = QKRMSNORMWeight(
+ dim=self.head_dim // 2,
+ q_weight_name=self._q_norm_hw_name,
+ k_weight_name=self._k_norm_hw_name,
+ data_type=self.data_type_,
+ )
diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py
new file mode 100644
index 0000000000..d9f40d7feb
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/model.py
@@ -0,0 +1,192 @@
+import os
+import json
+from lightllm.common.build_utils import repair_config
+from lightllm.models.registry import ModelRegistry, llm_model_type_is
+from lightllm.models.qwen3_vl.infer_struct import Qwen3VLInferStateInfo
+from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import Qwen3VLMultimodalPreLayerInfer
+from lightllm.models.qwen3_vl.layer_infer.transformer_layer_infer import Qwen3VLTransformerLayerInfer
+from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import Qwen3VLPreAndPostLayerWeight
+from lightllm.models.qwen2_vl.model import QWen2VLTokenizer
+from lightllm.models.qwen3.model import Qwen3TpPartModel
+from lightllm.server.core.objs import SamplingParams
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem
+from lightllm.models.neo_chat_moe.vision_process import smart_resize
+from lightllm.models.internvl.model import InternvlTokenizer
+from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer
+from lightllm.models.neo_chat_moe.layer_infer.transformer_layer_infer import NeoChatMOETransformerLayerInfer
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.neo_chat_moe.layer_weights.transformer_layer_weight import NeoChatMOETransformerLayerWeight
+from lightllm.models.neo_chat_moe.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight
+from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
+from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+
+IMG_START_TOKEN = "
"
+IMG_END_TOKEN = ""
+IMG_TOKEN = ""
+AUDIO_START_TOKEN = ""
+
+
+class NeoChatTokenizer(BaseMultiModalTokenizer):
+ def __init__(self, tokenizer, model_cfg, **kwargs):
+ super().__init__(tokenizer)
+ self.tokenizer = tokenizer
+ self.min_pixel = model_cfg.get("vision_config").get("min_pixels")
+ self.max_pixel = model_cfg.get("vision_config").get("max_pixels")
+ self.patch_size = model_cfg.get("vision_config").get("patch_size")
+ self.downsample_ratio = model_cfg.get("vision_config").get("downsample_ratio")
+
+ self.image_token_id = model_cfg.get("image_token_id")
+ self.image_start_tag = IMG_START_TOKEN
+ self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag)
+ self.image_end_tag = IMG_END_TOKEN
+ self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
+ self.conversation_module = self.load_conversion_module(tokenizer.name_or_path)
+ self.template = model_cfg.get("template", "neo1_0")
+
+ def load_conversion_module(self, model_dir: str):
+        import importlib.util
+
+ conversion_path = os.path.join(model_dir, "conversation.py")
+ if not os.path.exists(conversion_path):
+ return None
+
+ spec = importlib.util.spec_from_file_location("conversation", str(conversion_path))
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module) # must run the module
+ return module
+
+ def init_imageitem_extral_params(
+ self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+ ):
+ img.extra_params["min_pixels"] = (
+ sampling_params.min_pixels if sampling_params.min_pixels > 0 else self.min_pixel
+ )
+ img.extra_params["max_pixels"] = (
+ sampling_params.max_pixels if sampling_params.max_pixels > 0 else self.max_pixel
+ )
+ assert (
+ img.extra_params["min_pixels"] <= img.extra_params["max_pixels"]
+ ), "min_pixels should be less than or equal to max_pixels"
+ return
+
+ def init_audioitem_extral_params(
+ self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
+ ):
+ raise NotImplementedError
+
+ def get_audio_token_length(self, audio: AudioItem):
+ raise NotImplementedError
+
+ def get_image_token_length(self, img: ImageItem):
+ width, height = img.image_w, img.image_h
+ resized_height, resized_width = smart_resize(
+ height=height,
+ width=width,
+ factor=int(self.patch_size // self.downsample_ratio),
+ min_pixels=img.extra_params["min_pixels"],
+ max_pixels=img.extra_params["max_pixels"],
+ )
+ grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+ token_num = int((grid_h * grid_w) * (self.downsample_ratio ** 2))
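+        # illustrative example (assuming patch_size=16, downsample_ratio=0.5, and that smart_resize
+        # keeps a 448x448 image unchanged): grid is 28x28 = 784 patches -> 784 * 0.25 = 196 tokens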
+        # TODO: double-check whether grid_h and grid_w here need the extra "* self.downsample_ratio"; revisit this code
+ img.grid_thwd = (1, int(grid_h * self.downsample_ratio), int(grid_w * self.downsample_ratio), 1 - token_num)
+ return token_num
+
+ # only change the impl of the encode func:
+ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
+        # replace every IMG_TOKEN placeholder in the prompt with an IMG_START_TOKEN + IMG_END_TOKEN pair
+ image_tokens = IMG_START_TOKEN + IMG_END_TOKEN
+ if multimodal_params is None:
+ add_special_tokens = kwargs.get("add_special_tokens", True)
+ return self.tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+ image_count = len(multimodal_params.images)
+ if not kwargs.get("already_tokenized", False):
+ prompt = prompt.replace(IMG_TOKEN, image_tokens, image_count)
+ origin_ids = self.tokenizer.encode(prompt, add_special_tokens=kwargs["add_special_tokens"])
+ else:
+ origin_ids = prompt
+        # expand every image_start_id / image_end_id pair into:
+        #   image_start_id, token_id, token_id+1, ..., token_id+token_num-1, image_end_id
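+        # e.g. with token_id=10 and token_num=3:
+        #   [..., image_start_id, image_end_id, ...] -> [..., image_start_id, 10, 11, 12, image_end_id, ...]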
+ input_ids = []
+ image_id = 0
+ start_idx = 0
+ while True:
+ try:
+ start_idx = origin_ids.index(self.image_start_id)
+ if start_idx + 1 >= len(origin_ids):
+ break
+ if origin_ids[start_idx + 1] == self.image_end_id:
+ input_ids.extend(origin_ids[: start_idx + 1])
+ token_id = multimodal_params.images[image_id].token_id
+ token_num = multimodal_params.images[image_id].token_num
+ multimodal_params.images[image_id].start_idx = len(input_ids)
+ input_ids.extend(range(token_id, token_id + token_num))
+ input_ids.append(self.image_end_id)
+ origin_ids = origin_ids[start_idx + 2 :]
+ image_id += 1
+ else:
+ raise ValueError("image token error")
+ except ValueError:
+ break
+ input_ids.extend(origin_ids)
+ return input_ids
+
+ def _build_t2i_query(self, msg, thinking_content=""):
+ template = self.conversation_module.get_conv_template(self.template)
+ template.append_message(template.roles[0], msg)
+ template.append_message(template.roles[1], None)
+ return template.get_prompt() + thinking_content + IMG_START_TOKEN
+
+ def fix_prompt(self, prompt: str, img_len: int):
+ prompt_img_len = prompt.count(IMG_TOKEN)
+ assert prompt_img_len <= img_len, f"not enough images provided, need {prompt_img_len}, given {img_len}"
+ if prompt_img_len < img_len:
+ return f"{IMG_TOKEN}\n" * (img_len - prompt_img_len) + prompt
+ return prompt
+
+ def get_query_for_it2i(self, prompt: str):
+ image_len = prompt.count(IMG_TOKEN)
+ query_condition = self._build_t2i_query(prompt, thinking_content="\n\n\n\n")
+ query_text_uncondition = self._build_t2i_query(IMG_TOKEN * image_len)
+ question_img_uncondition = self._build_t2i_query("")
+ return query_condition, query_text_uncondition, question_img_uncondition
+
+ def get_query_for_t2i(self, prompt):
+ query_condition = self._build_t2i_query(
+ f"Please generate an image based on the following description: {prompt}",
+ thinking_content="\n\n\n\n",
+ )
+        query_uncondition = self._build_t2i_query("")
+ return query_condition, query_uncondition
+
+
+@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3_moe"))
+class NeoTpMOEPartModel(Qwen3MOEModel):
+
+ pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+ transformer_layer_infer_class = NeoChatMOETransformerLayerInfer
+
+ pre_and_post_weight_class = NeoChatMOEPreAndPostLayerWeight
+ transformer_weight_class = NeoChatMOETransformerLayerWeight
+
+ infer_state_class = NeoChatInferStateInfo
+
+ def __init__(self, kvargs):
+ super().__init__(kvargs)
+ return
+
+ def _init_inferstate_cls(self):
+ pass
+
+ def _init_config(self):
+ with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+ all_config = json.load(json_file)
+ self.config = all_config["llm_config"]
+ # rename keys
+ repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+ repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+ repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+ if self.finetune_config:
+ self.config["vocab_size"] = self.finetune_config.vocab_size
+ return
diff --git a/lightllm/models/neo_chat_moe/neo_visual.py b/lightllm/models/neo_chat_moe/neo_visual.py
new file mode 100644
index 0000000000..a516a99f64
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/neo_visual.py
@@ -0,0 +1,281 @@
+import os
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from typing import List
+from io import BytesIO
+import torch.nn as nn
+from transformers.activations import ACT2FN
+from safetensors import safe_open
+from lightllm.server.multimodal_params import ImageItem
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+from lightllm.models.neo_chat_moe.vision_process import load_image_native
+from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
+
+
+def apply_rotary_emb_1d(
+ x: torch.Tensor,
+ cos_cached: torch.Tensor,
+ sin_cached: torch.Tensor,
+ positions: torch.Tensor,
+):
+ """对输入张量的一部分应用1D RoPE。"""
+ # x: (..., seq_len, dim_part)
+ # positions: (..., seq_len)
+ # cos_cached: (max_pos, dim_part / 2)
+ cos_cached = cos_cached.to(device=positions.device)
+ sin_cached = sin_cached.to(device=positions.device)
+
+ cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2)
+ sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2)
+
+ x1 = x[..., 0::2]
+ x2 = x[..., 1::2]
+
+ rotated_x1 = x1 * cos - x2 * sin
+ rotated_x2 = x1 * sin + x2 * cos
+
+ x_rotated = torch.empty_like(x)
+ x_rotated[..., 0::2] = rotated_x1
+ x_rotated[..., 1::2] = rotated_x2
+ return x_rotated
+
+
+def apply_2d_rotary_pos_emb(
+ x: torch.Tensor,
+ cos_cached_x: torch.Tensor,
+ sin_cached_x: torch.Tensor,
+ cos_cached_y: torch.Tensor,
+ sin_cached_y: torch.Tensor,
+ abs_positions_x: torch.Tensor,
+ abs_positions_y: torch.Tensor,
+):
+ """应用2D RoPE到输入张量x。"""
+ dim = x.shape[-1]
+ dim_half = dim // 2
+
+    # Assume the first half of the embedding carries RoPE for one axis and the second half for the other,
+    # e.g. the first half for the X coordinate and the second half for Y (or vice versa, as long as it stays consistent).
+ x_part_1 = x[..., :dim_half]
+ x_part_2 = x[..., dim_half:]
+
+    # apply the rotation driven by abs_positions_x to x_part_1
+ rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x)
+    # apply the rotation driven by abs_positions_y to x_part_2
+ rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y)
+
+    # concatenate them back together; the order must match the split above.
+ return torch.cat((rotated_part_1, rotated_part_2), dim=-1)
+
+
+def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
+ """
+ Compute patch coordinates (x, y)
+
+ Args:
+ grid_hw: (B, 2) tensor representing (H, W) per image
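+
+    Example:
+        grid_hw = [[2, 3]] -> abs_x = [0, 1, 2, 0, 1, 2], abs_y = [0, 0, 0, 1, 1, 1]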
+ """
+ device = grid_hw.device
+ B = grid_hw.shape[0]
+
+ # Get the number of patches per image
+ H = grid_hw[:, 0]
+ W = grid_hw[:, 1]
+ N = H * W
+ N_total = N.sum()
+
+ # Create the batch index for each patch (B x patch count)
+ patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,)
+
+ # Generate intra-image patch index (row-major order)
+ patch_id_within_image = torch.arange(N_total, device=device)
+ patch_id_within_image = (
+ patch_id_within_image
+ - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample]
+ )
+
+ # Get H/W for each patch according to its image
+ W_per_patch = W[patch_to_sample]
+ abs_x = patch_id_within_image % W_per_patch
+ abs_y = patch_id_within_image // W_per_patch
+
+ return abs_x, abs_y
+
+
+class NeoVisionTransformerPretrainedModel(nn.Module):
+ def __init__(
+ self,
+ kvargs,
+ hidden_size: int = 1024,
+ llm_hidden_size: int = 2048,
+ downsample_ratio: float = 0.5,
+ patch_size: int = 16,
+ num_channels: int = 3,
+ max_position_embeddings_vision: int = 10000,
+ rope_theta_vision: float = 10000.0,
+ min_pixels: int = 65536,
+ max_pixels: int = 2408448,
+ **kwargs,
+ ):
+ super().__init__()
+ self.weight_dir = kvargs["weight_dir"]
+ self.data_type = kvargs.get("data_type", "bfloat16")
+ self.embed_dim = hidden_size
+ self.llm_hidden_size = llm_hidden_size
+ self.patch_size = patch_size
+ self.num_channels = num_channels
+ self.downsample_ratio = downsample_ratio
+ self.downsample_factor = int(1 / downsample_ratio)
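+        # e.g. downsample_ratio=0.5 -> downsample_factor=2, so dense_embedding merges 2x2 patch blocks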
+ self.max_position_embeddings_vision = max_position_embeddings_vision
+ self.rope_theta_vision = rope_theta_vision
+ self.rope_dim_part = self.embed_dim // 2
+ self.min_pixels = min_pixels
+ self.max_pixels = max_pixels
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=num_channels, out_channels=self.embed_dim, kernel_size=patch_size, stride=patch_size
+ )
+
+ self.dense_embedding = nn.Conv2d(
+ in_channels=self.embed_dim,
+ out_channels=self.llm_hidden_size,
+ kernel_size=self.downsample_factor,
+ stride=self.downsample_factor,
+ )
+ self.gelu = nn.GELU()
+
+ self.repe_dim_part = self.embed_dim // 2
+ self.cos_x, self.sin_x = self.precompute_rope_freqs_sincos()
+ self.cos_y, self.sin_y = self.precompute_rope_freqs_sincos()
+ self._init_datatype()
+
+ def _init_datatype(self):
+ if isinstance(self.data_type, torch.dtype):
+ return
+ if self.data_type in ["fp16", "float16"]:
+ self.data_type = torch.float16
+ elif self.data_type in ["bf16", "bfloat16"]:
+ self.data_type = torch.bfloat16
+ elif self.data_type in ["fp32", "float32"]:
+ self.data_type = torch.float32
+ else:
+ raise ValueError(f"Unsupport datatype {self.data_type}!")
+ return
+
+ def load_model(self, weight_dir):
+ bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
+ if bin_weight_files:
+ weight_dict = {}
+ for file_ in bin_weight_files:
+ f = torch.load(os.path.join(weight_dir, file_), "cpu")
+ for k, v in f.items():
+ if "vision_model" in k and "fm_modules" not in k:
+ weight_dict[k[len("vision_model.embeddings.") :]] = v
+ else:
+ hf_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".safetensors")]
+ weight_dict = {}
+ for file_ in hf_weight_files:
+ f = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
+ for k in f.keys():
+ if "vision_model" in k and "fm_modules" not in k:
+ weight_dict[k[len("vision_model.embeddings.") :]] = f.get_tensor(k)
+ self.load_state_dict(weight_dict)
+
+ def precompute_rope_freqs_sincos(self):
+ inv_freq = 1.0 / (
+ self.rope_theta_vision ** (torch.arange(0, self.rope_dim_part, 2).float() / self.rope_dim_part)
+ )
+ t = torch.arange(self.max_position_embeddings_vision).type_as(inv_freq)
+ freqs = torch.outer(t, inv_freq)
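+        # freqs: (max_position_embeddings_vision, rope_dim_part // 2)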
+ return torch.cos(freqs), torch.sin(freqs)
+
+ def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw):
+ """
+ Apply 2D Rotary Position Embedding to the patch embeddings.
+ """
+ abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device)
+ embeddings = apply_2d_rotary_pos_emb(
+ patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32
+ self.cos_x,
+ self.sin_x,
+ self.cos_y,
+ self.sin_y,
+ abs_pos_x,
+ abs_pos_y,
+ ).to(self.patch_embedding.weight.dtype)
+ return embeddings
+
+ def forward(self, pixel_values: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
+ pixel_values = pixel_values.view(
+ -1,
+ 3,
+ self.patch_size,
+ self.patch_size,
+ )
+ patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim)
+ patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw)
+ assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[
+ 0
+ ], "Grid size and patch embeds size mismatch."
+
+ patches_list = []
+ cur_position = 0
+ for i in range(grid_hw.shape[0]):
+ h, w = grid_hw[i]
+ patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0)
+ patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2))
+ patches_per_img = patches_per_img.permute(0, 2, 3, 1)
+ patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1]))
+ cur_position += h * w
+
+ embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C)
+ assert cur_position == patch_embeds.shape[0]
+ assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2)
+
+ return embeddings
+
+ def encode(self, images: List[ImageItem]):
+ img_tensors = []
+ valid_ids = []
+ valid_id = 0
+ img_grids = []
+ uuids = []
+
+ for i, img in enumerate(images):
+ if isinstance(img, ImageItem):
+ uuids.append(img.uuid)
+ image_data = read_shm(get_shm_name_data(img.uuid))
+ image_data = Image.open(BytesIO(image_data))
+ # a = img.extra_params["min_pixels"]
+ # b = img.extra_params["max_pixels"]
+ # print(f"self.min_pixels is {a} ,max_pixelx is {b}")
+ pixel_values, image_grid_hw = load_image_native(
+ image_data,
+ patch_size=self.patch_size,
+ downsample_ratio=self.downsample_ratio,
+ min_pixels=img.extra_params["min_pixels"],
+ max_pixels=img.extra_params["max_pixels"],
+ )
+ img_tensors.append(pixel_values)
+ img_grids.append(image_grid_hw)
+ else:
+ raise Exception("Unsupport input types: {} for {}".format(type(img), img))
+
+            # the patch count must be divisible by the merge length (downsample_factor ** 2)
+ cur_num = int(img_tensors[-1].shape[0] * (self.downsample_ratio ** 2))
+ valid_ids.append([valid_id, valid_id + cur_num])
+ valid_id += cur_num
+
+ if len(img_tensors) <= 0:
+ return None
+
+ imgs = torch.cat(img_tensors, dim=0)
+ grid_hw = torch.cat(img_grids, dim=0)
+
+ pixel_values = imgs.to("cuda", dtype=self.data_type, non_blocking=True)
+ image_grid_hw = grid_hw.to("cuda", non_blocking=True)
+
+ all_img_embeds = self.forward(pixel_values, grid_hw=image_grid_hw)
+
+ return all_img_embeds, uuids, valid_ids
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/__init__.py b/lightllm/models/neo_chat_moe/triton_kernel/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
new file mode 100644
index 0000000000..74ff82cae4
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
@@ -0,0 +1,430 @@
+import math
+import torch
+import triton
+import triton.language as tl
+
+from lightllm.utils.device_utils import is_tesla
+
+
+@triton.jit
+def _fwd_kernel(
+ Q,
+ K,
+ V,
+ sm_scale,
+ Out,
+ position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0]
+ B_Start_Loc,
+ B_Seqlen,
+ Req_to_tokens,
+ B_req_idx,
+ stride_qbs,
+ stride_qh,
+ stride_qd,
+ stride_kbs,
+ stride_kh,
+ stride_kd,
+ stride_vbs,
+ stride_vh,
+ stride_vd,
+ stride_obs,
+ stride_oh,
+ stride_od,
+ stride_req_to_tokens_b,
+ stride_req_to_tokens_s,
+ kv_group_num,
+ b_prompt_cache_len,
+ b_image_token_tag,
+ H: tl.constexpr,
+ QK_HEAD_DIM: tl.constexpr,
+ V_HEAD_DIM: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+):
+ start_m = tl.program_id(0)
+ cur_bh = tl.program_id(1)
+ cur_batch = cur_bh // H
+ cur_head = cur_bh % H
+
+ cur_kv_head = cur_head // kv_group_num
+
+ cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+ prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)
+ total_len = tl.load(B_Seqlen + cur_batch)
+ cur_batch_seq_len = total_len - prompt_cache_len # NEW len
+ cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+ block_start_loc = BLOCK_M * start_m
+ if block_start_loc >= cur_batch_seq_len:
+ return
+
+ offs_n = tl.arange(0, BLOCK_N)
+ offs_d_qk = tl.arange(0, QK_HEAD_DIM)
+ offs_d_v = tl.arange(0, V_HEAD_DIM)
+ offs_m = block_start_loc + tl.arange(0, BLOCK_M)
+
+ # Q pointers
+ off_q = (
+ (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs
+ + cur_head * stride_qh
+ + offs_d_qk[None, :] * stride_qd
+ )
+
+ q_valid = offs_m < cur_batch_seq_len
+ q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0)
+
+ # online softmax state
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+ acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)
+ block_end_loc = total_len
+
+ # absolute q positions in the request
+ q_pos = prompt_cache_len + offs_m # [M]
+ q_image_token_tag = tl.load(b_image_token_tag + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=False)
+
+ for start_n in range(0, block_end_loc, BLOCK_N):
+ start_n = tl.multiple_of(start_n, BLOCK_N)
+
+ k_pos = start_n + offs_n # [N]
+ k_valid = k_pos < block_end_loc
+
+ # map logical pos -> mem_index (for K/V)
+ kv_loc = tl.load(
+ Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos,
+ mask=k_valid,
+ other=0,
+ ).to(tl.int64)
+
+ # load K
+ off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d_qk[:, None] * stride_kd
+ k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0)
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ qk += tl.dot(q, k)
+
+ # mask: causal OR same gid (only possible inside NEW part)
+ mask = (q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None]
+ qk = tl.where(mask, qk * sm_scale, -1.0e8)
+
+ # online softmax
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
+ qk -= m_ij[:, None]
+ p = tl.math.exp2(qk)
+ l_ij = tl.sum(p, 1)
+
+ alpha = tl.math.exp2(m_i - m_ij)
+ l_i = l_i * alpha + l_ij
+ acc = acc * alpha[:, None]
+
+ # load V
+ off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d_v[None, :] * stride_vd
+ v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0)
+
+ p = p.to(v.dtype)
+ acc = tl.dot(p, v, acc)
+
+ m_i = m_ij
+
+ acc = acc / l_i[:, None]
+
+ off_o = (
+ (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs
+ + cur_head * stride_oh
+ + offs_d_v[None, :] * stride_od
+ )
+ tl.store(Out + off_o, acc, mask=q_valid[:, None])
+
+
+@torch.no_grad()
+def context_attention_fwd_neo(
+ q,
+ k,
+ v,
+ o,
+ position_ids, # 1D packed like q (only NEW tokens)
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ max_input_len,
+ req_to_token_indexs,
+ b_image_token_tag,
+):
+ # minimal safety: position_ids must cover packed q rows
+ assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0])
+
+ BLOCK_M = 128 if not is_tesla() else 64
+
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+ assert Lq == Lk and Lk == Lv
+ assert Lk in {16, 32, 64, 128, 256}
+ sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634
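+    # the extra factor 1.4426950408889634 = log2(e); the kernel uses exp2 for the online softmax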
+
+ batch, head = b_seq_len.shape[0], q.shape[1]
+ kv_group_num = q.shape[1] // k.shape[1]
+
+ grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1)
+
+ BLOCK_N = BLOCK_M
+ num_warps = 4 if Lk <= 64 else 8
+ num_stages = 1
+
+ _fwd_kernel[grid](
+ q,
+ k,
+ v,
+ sm_scale,
+ o,
+ position_ids,
+ b_start_loc,
+ b_seq_len,
+ req_to_token_indexs,
+ b_req_idx,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k.stride(0),
+ k.stride(1),
+ k.stride(2),
+ v.stride(0),
+ v.stride(1),
+ v.stride(2),
+ o.stride(0),
+ o.stride(1),
+ o.stride(2),
+ req_to_token_indexs.stride(0),
+ req_to_token_indexs.stride(1),
+ kv_group_num=kv_group_num,
+ b_prompt_cache_len=b_prompt_cache_len,
+ b_image_token_tag=b_image_token_tag,
+ H=head,
+ QK_HEAD_DIM=Lk,
+ V_HEAD_DIM=Lk,
+ BLOCK_M=BLOCK_M,
+ BLOCK_N=BLOCK_N,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+
+
+def reference_attention(
+ q,
+ k,
+ v,
+ position_ids_q, # 1D packed like q (only NEW tokens)
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ req_to_token_indexs,
+):
+ device = q.device
+ dtype = q.dtype
+ sum_q, Hq, D = q.shape
+ Hk = k.shape[1]
+ kv_group_num = Hq // Hk
+
+ batch = b_seq_len.shape[0]
+ out = torch.empty_like(q)
+ scale = 1.0 / math.sqrt(D)
+
+ for b in range(batch):
+ req = int(b_req_idx[b].item())
+ total_len = int(b_seq_len[b].item())
+ prompt_len = int(b_prompt_cache_len[b].item())
+ new_len = total_len - prompt_len
+
+ q_start = int(b_start_loc[b].item())
+ q_blk = q[q_start : q_start + new_len] # [M, Hq, D]
+ gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M]
+
+ # gather K/V for full request by logical pos -> mem_index
+ token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L]
+ k_blk = k[token_locs] # [L, Hk, D]
+ v_blk = v[token_locs] # [L, Hk, D]
+
+ # expand kv heads to q heads (GQA)
+ k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D]
+ v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) # [L, Hq, D]
+
+ # positions
+ q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) # [M]
+ k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) # [L]
+
+ # build allow mask:
+ # causal always
+ allow = k_pos[None, :] <= q_pos[:, None]
+
+ # full-attn only inside NEW part by gid
+ # compare only when k_pos in NEW
+ k_in_new = k_pos >= prompt_len
+ k_rel = (k_pos - prompt_len).clamp_min(0) # [L]
+ # map k_rel to gid_new, but only valid where k_in_new
+ k_gid = torch.empty((total_len,), device=device, dtype=torch.int64)
+ k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new
+ k_gid[k_in_new] = gid_new[k_rel[k_in_new]]
+
+ allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :])
+
+ # scores: [Hq, M, L]
+ q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D]
+ k_t = k_hq.permute(1, 2, 0).to(torch.float32) # [Hq, D, L]
+ scores = torch.matmul(q_t, k_t) * scale # [Hq, M, L]
+
+ neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32)
+ scores = torch.where(allow[None, :, :], scores, neg)
+
+ p = torch.softmax(scores, dim=-1).to(torch.float32) # [Hq, M, L]
+ v_t = v_hq.permute(1, 0, 2).to(torch.float32) # [Hq, L, D]
+ out_hq = torch.matmul(p, v_t) # [Hq, M, D]
+ out_blk = out_hq.permute(1, 0, 2).to(dtype) # [M, Hq, D]
+
+ out[q_start : q_start + new_len] = out_blk
+
+ return out
+
+
+def make_test_case(
+ device="cuda",
+ dtype=torch.float16,
+ batch=3,
+ Hq=8,
+ Hk=4,
+ D=64,
+ seed=0,
+ base_index=50000,
+):
+ torch.manual_seed(seed)
+
+ # prompt (cached) len and new len
+ prompt_lens = torch.randint(low=2, high=8, size=(batch,), device=device)
+ new_lens = torch.randint(low=1, high=8, size=(batch,), device=device)
+ total_lens = (prompt_lens + new_lens).to(torch.int32)
+
+ max_total_len = int(total_lens.max().item())
+ max_new_len = int(new_lens.max().item())
+
+ # packed q start
+ b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32)
+ cur = 0
+ for b in range(batch):
+ b_start_loc[b] = cur
+ cur += int(new_lens[b].item())
+ sum_q = cur
+
+ b_seq_len = total_lens
+ b_prompt_cache_len = prompt_lens.to(torch.int32)
+
+ # one req per batch
+ num_req = batch
+ b_req_idx = torch.arange(batch, device=device, dtype=torch.int32)
+
+ # global KV space large, indices not small
+ sum_kv = int(total_lens.sum().item())
+ kv_size = base_index + sum_kv + 1024
+ pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index
+
+ # Req_to_tokens [num_req, max_total_len]
+ req_to_token_indexs = torch.zeros((num_req, max_total_len), device=device, dtype=torch.int32)
+ p = 0
+ for r in range(num_req):
+ L = int(total_lens[r].item())
+ req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32)
+ p += L
+
+ # position_ids_q: only NEW tokens, packed like q
+    position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32)
+    b_image_token_tag = torch.zeros((sum_q,), device=device, dtype=torch.bool)
+ for b in range(batch):
+ M = int(new_lens[b].item())
+ start = int(b_start_loc[b].item())
+
+ gid = torch.arange(M, device=device, dtype=torch.int32)
+
+ # make one repeated block inside NEW part to simulate image tokens
+ if M >= 4 and torch.rand((), device=device).item() > 0.3:
+ s = int(torch.randint(0, M - 2, (1,), device=device).item())
+ e = min(M, s + 3)
+ gid[s:e] = gid[s]
+
+ position_ids_q[start : start + M] = gid
+
+ q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype)
+ k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype)
+ v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype)
+ o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype)
+
+ return (
+ q,
+ k,
+ v,
+ o,
+ position_ids_q,
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ max_new_len,
+ req_to_token_indexs,
+ )
+
+
+def check_once(device="cuda", dtype=torch.float16, seed=0):
+ (
+ q,
+ k,
+ v,
+ o,
+ position_ids_q,
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ max_new_len,
+ req_to_token_indexs,
+ ) = make_test_case(device=device, dtype=dtype, seed=seed)
+
+ context_attention_fwd_neo(
+ q,
+ k,
+ v,
+ o,
+ position_ids_q,
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ max_new_len,
+ req_to_token_indexs,
+ )
+
+ ref = reference_attention(
+ q,
+ k,
+ v,
+ position_ids_q,
+ b_req_idx,
+ b_start_loc,
+ b_seq_len,
+ b_prompt_cache_len,
+ req_to_token_indexs,
+ )
+
+ diff = (o - ref).abs()
+ max_abs = diff.max().item()
+ denom = ref.abs().max().item() + 1e-6
+ max_rel = max_abs / denom
+
+ print(f"seed={seed}, dtype={dtype}")
+ print(f"max_abs_error = {max_abs:.6e}")
+ print(f"max_rel_error = {max_rel:.6e}")
+ print("allclose(fp16 tol)?", torch.allclose(o, ref, atol=5e-2, rtol=5e-2))
+
+
+if __name__ == "__main__":
+ if not torch.cuda.is_available():
+ print("No CUDA, skip.")
+ else:
+ torch.cuda.synchronize()
+ check_once(dtype=torch.bfloat16, seed=0)
+ check_once(dtype=torch.bfloat16, seed=1)
+ check_once(dtype=torch.bfloat16, seed=2)
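As a side note on the masking rule the reference above encodes: a key is visible to a query either causally (it sits at an earlier position) or because both tokens share the same group id inside the NEW segment, which is how image tokens attend to each other bidirectionally. A minimal standalone sketch of that rule, with made-up sizes that are not part of the patch:

    import torch

    # 2 cached prompt tokens, 5 NEW tokens whose group ids contain one
    # repeated block [3, 3, 3] standing in for an image.
    prompt_len, new_len = 2, 5
    gid_new = torch.tensor([2, 3, 3, 3, 4])
    total_len = prompt_len + new_len

    q_pos = torch.arange(prompt_len, total_len)      # NEW queries
    k_pos = torch.arange(0, total_len)               # all keys

    allow = k_pos[None, :] <= q_pos[:, None]         # causal part
    k_gid = torch.full((total_len,), -1)             # sentinel, never matches a real gid
    k_gid[prompt_len:] = gid_new
    allow |= gid_new[:, None] == k_gid[None, :]      # bidirectional inside the image block

    print(allow.int())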
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
new file mode 100644
index 0000000000..1a3d4af73b
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
@@ -0,0 +1,191 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _get_neo_position_triton(
+ b_image_start_idx: torch.Tensor,
+ b_image_thwd: torch.Tensor,
+ b_image_thwd_stride0: torch.Tensor,
+ b_image_nums: torch.Tensor,
+ b_image_start_num: torch.Tensor,
+ b_image_len: torch.Tensor,
+ position_ids: torch.Tensor,
+ position_ids_stride0: torch.Tensor,
+ b_ready_cache_len: torch.Tensor,
+ b_q_seq_len: torch.Tensor,
+ b_start_loc: torch.Tensor,
+ b_image_token_tag: torch.Tensor,
+ BLOCK_SIZE: tl.constexpr,
+):
+ cur_batch = tl.program_id(0)
+ cache_len = tl.load(b_ready_cache_len + cur_batch)
+ q_seq_len = tl.load(b_q_seq_len + cur_batch)
+ image_num = tl.load(b_image_nums + cur_batch)
+ image_start_num = tl.load(b_image_start_num + cur_batch)
+ start_loc = tl.load(b_start_loc + cur_batch)
+ for i in range(image_num):
+ local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
+ image_start_idx = start_loc + local_image_start_idx - cache_len
+ image_len = tl.load(b_image_len + image_start_num + i)
+ # image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
+ image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
+ for j in range(0, image_len, BLOCK_SIZE):
+ off = j + tl.arange(0, BLOCK_SIZE)
+ # video is not handled yet, so t is always 0
+ t_pos = local_image_start_idx + off * 0
+ h_pos = off // image_w
+ w_pos = off % image_w
+ tl.store(
+ b_image_token_tag + off + image_start_idx,
+ True,
+ mask=(off < image_len)
+ & (off + local_image_start_idx - cache_len < q_seq_len)
+ & (local_image_start_idx - cache_len + off >= 0),
+ )
+ tl.store(
+ position_ids + off + image_start_idx,
+ t_pos,
+ mask=(off < image_len)
+ & (off + local_image_start_idx - cache_len < q_seq_len)
+ & (local_image_start_idx - cache_len + off >= 0),
+ )
+ tl.store(
+ position_ids + position_ids_stride0 + off + image_start_idx,
+ h_pos,
+ mask=(off < image_len)
+ & (off + local_image_start_idx - cache_len < q_seq_len)
+ & (local_image_start_idx - cache_len + off >= 0),
+ )
+ tl.store(
+ position_ids + position_ids_stride0 * 2 + off + image_start_idx,
+ w_pos,
+ mask=(off < image_len)
+ & (off + local_image_start_idx - cache_len < q_seq_len)
+ & (local_image_start_idx - cache_len + off >= 0),
+ )
+
+ for i in range(image_num):
+ local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
+ image_len = tl.load(b_image_len + image_start_num + i)
+ image_delta = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 3)
+ image_end = local_image_start_idx + image_len - cache_len
+ text_start = tl.maximum(0, image_end)
+ for j in range(text_start, q_seq_len, BLOCK_SIZE):
+ off = j + tl.arange(0, BLOCK_SIZE)
+ t_pos = tl.load(position_ids + off + start_loc, mask=(off < q_seq_len), other=0.0) + image_delta
+ h_pos = tl.load(position_ids + position_ids_stride0 + off + start_loc, mask=(off < q_seq_len), other=0.0)
+ w_pos = tl.load(
+ position_ids + position_ids_stride0 * 2 + off + start_loc, mask=(off < q_seq_len), other=0.0
+ )
+ tl.store(position_ids + off + start_loc, t_pos, mask=(off < q_seq_len))
+ tl.store(position_ids + position_ids_stride0 + off + start_loc, h_pos, mask=(off < q_seq_len))
+ tl.store(position_ids + position_ids_stride0 * 2 + off + start_loc, w_pos, mask=(off < q_seq_len))
+ return
+
+
+def get_neo_position_triton(
+ b_image_start_idx: torch.Tensor,
+ b_image_thwd: torch.Tensor,
+ b_image_nums: torch.Tensor,
+ b_image_start_num: torch.Tensor,
+ b_image_len: torch.Tensor,
+ position_ids: torch.Tensor,
+ b_ready_cache_len: torch.Tensor,
+ b_q_seq_len: torch.Tensor,
+ b_start_loc: torch.Tensor,
+ b_image_token_tag: torch.Tensor,
+) -> None:
+
+ batch_size = b_q_seq_len.shape[0]
+ assert batch_size == b_image_nums.shape[0]
+ grid = (batch_size,)
+ BLOCK_SIZE = 64
+ _get_neo_position_triton[grid](
+ b_image_start_idx=b_image_start_idx,
+ b_image_thwd=b_image_thwd,
+ b_image_thwd_stride0=b_image_thwd.stride(0),
+ b_image_nums=b_image_nums,
+ b_image_start_num=b_image_start_num,
+ b_image_len=b_image_len,
+ position_ids=position_ids,
+ position_ids_stride0=position_ids.stride(0),
+ b_ready_cache_len=b_ready_cache_len,
+ b_q_seq_len=b_q_seq_len,
+ b_start_loc=b_start_loc,
+ b_image_token_tag=b_image_token_tag,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+
+
+def test():
+ b_image_start_idx = torch.tensor([0, 0, 4], dtype=torch.int32, device="cuda")
+ b_image_thwd = torch.tensor([[1, 2, 2, -3], [1, 2, 2, -3], [1, 2, 4, -7]], dtype=torch.int32, device="cuda")
+ b_image_nums = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
+ b_image_start_num = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
+ b_image_len = torch.tensor([4, 4, 8], dtype=torch.int32, device="cuda")
+ position_ids = (
+ torch.tensor([0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda")
+ .unsqueeze(0)
+ .expand(3, -1)
+ .contiguous()
+ )
+ b_image_token_tag = torch.zeros([position_ids.size(1)], dtype=torch.bool, device="cuda")
+ position_ids[1:].zero_()
+ b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda")
+ b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda")
+ b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda")
+ get_neo_position_triton(
+ b_image_start_idx,
+ b_image_thwd,
+ b_image_nums,
+ b_image_start_num,
+ b_image_len,
+ position_ids,
+ b_ready_cache_len,
+ b_q_seq_len,
+ b_start_loc,
+ b_image_token_tag,
+ )
+
+ print(b_image_token_tag)
+ print(position_ids)
+ # old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1)
+
+ # position_ids = (
+ # torch.tensor([2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=torch.int32, device="cuda")
+ # .unsqueeze(0)
+ # .expand(3, -1)
+ # .contiguous()
+ # )
+ # b_ready_cache_len = torch.tensor([2, 2], dtype=torch.int32, device="cuda")
+ # b_q_seq_len = torch.tensor([5, 11], dtype=torch.int32, device="cuda")
+ # b_start_loc = torch.tensor([0, 5], dtype=torch.int32, device="cuda")
+
+ # get_neo_position_triton(
+ # b_image_start_idx,
+ # b_image_thwd,
+ # b_image_nums,
+ # b_image_start_num,
+ # b_image_len,
+ # position_ids,
+ # b_ready_cache_len,
+ # b_q_seq_len,
+ # b_start_loc,
+ # )
+
+ # print(f"old_value:\n{old_value}")
+ # print(f"position_ids:\n{position_ids}")
+ # assert torch.equal(old_value, position_ids)
+
+ """
+ tensor([[0, 0, 0, 0, 2, 3, 4, 0, 0, 0, 0, 2, 2, 2, 2, 4, 5, 6, 7, 8],
+ [0, 0, 1, 1, 2, 3, 4, 0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 8],
+ [0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8]],
+ device='cuda:0', dtype=torch.int32)
+ """
+
+
+if __name__ == "__main__":
+ test()
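For readers who prefer eager PyTorch, this is roughly what the kernel computes for a single image in one request (cache_len = 0, one image only; illustrative only, not part of the patch): the t row stays at the image start index, h and w follow row-major order over the image grid, and the trailing text has its t positions shifted by the thwd delta.

    import torch

    def neo_positions_for_one_image(seq_len, img_start, img_len, image_w, image_delta):
        # rows: t / h / w, one column per query token
        pos = torch.arange(seq_len, dtype=torch.int32).repeat(3, 1)
        off = torch.arange(img_len)
        pos[0, img_start:img_start + img_len] = img_start        # t constant (no video yet)
        pos[1, img_start:img_start + img_len] = off // image_w   # h
        pos[2, img_start:img_start + img_len] = off % image_w    # w
        pos[0, img_start + img_len:] += image_delta              # shift t of the trailing text
        return pos

    print(neo_positions_for_one_image(seq_len=7, img_start=0, img_len=4, image_w=2, image_delta=-3))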
diff --git a/lightllm/models/neo_chat_moe/vision_process.py b/lightllm/models/neo_chat_moe/vision_process.py
new file mode 100644
index 0000000000..fbd57a5e9c
--- /dev/null
+++ b/lightllm/models/neo_chat_moe/vision_process.py
@@ -0,0 +1,141 @@
+import re
+import math
+import torch
+import string
+import numpy as np
+import pandas as pd
+from PIL import Image
+import torch.distributed as dist
+import torchvision.transforms as T
+
+IMAGENET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_STD = (0.229, 0.224, 0.225)
+
+
+def round_by_factor(number: int, factor: int) -> int:
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
+ return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+ return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+ return math.floor(number / factor) * factor
+
+
+# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py#L60
+def smart_resize(
+ height: int, width: int, factor: int = 32, min_pixels: int = 65536, max_pixels: int = 4194304
+) -> tuple[int, int]:
+ """
+ Rescales the image so that the following conditions are met:
+
+ 1. Both dimensions (height and width) are divisible by 'factor'.
+
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+ 3. The aspect ratio of the image is maintained as closely as possible.
+ """
+ if max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than {200}, got {max(height, width) / min(height, width)}"
+ )
+ h_bar = max(factor, round_by_factor(height, factor))
+ w_bar = max(factor, round_by_factor(width, factor))
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = max(factor, floor_by_factor(height / beta, factor))
+ w_bar = max(factor, floor_by_factor(width / beta, factor))
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = ceil_by_factor(height * beta, factor)
+ w_bar = ceil_by_factor(width * beta, factor)
+ return h_bar, w_bar
+
+
+def dynamic_preprocess_native_resolution(image, size_factor=32, min_pixels=65536, max_pixels=4194304, **kwargs):
+ width, height = image.size
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=size_factor,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ )
+ image = image.resize((resized_width, resized_height))
+
+ return image
+
+
+def preprocess_pixel_values(pixel_values, patch_size=16):
+ c, h, w = pixel_values.shape
+ grid_h = h // patch_size
+ grid_w = w // patch_size
+
+ flatten_pixel_values = (
+ pixel_values.view(c, grid_h, patch_size, grid_w, patch_size)
+ .permute(1, 3, 0, 2, 4) # [grid_h, grid_w, c, patch_size, patch_size]
+ .reshape(grid_h * grid_w, c * patch_size ** 2)
+ )
+
+ grid_hw = torch.tensor([[grid_h, grid_w]]).to(device=pixel_values.device)
+
+ return flatten_pixel_values, grid_hw
+
+
+def get_contrasting_background(image):
+ """
+ Calculate the color (white or black) that is different from the average foreground color
+ to use as the background color
+ """
+ image_np = np.array(image)
+ if (image_np[:, :, 3] == 0).any():
+ non_transparent_pixels = image_np[:, :, :3][image_np[:, :, 3] > 0]
+ if non_transparent_pixels.size == 0:
+ return None
+ pixel_mean = non_transparent_pixels.mean()
+            contrasting_color = (0, 0, 0) if pixel_mean > 127.5 else (255, 255, 255)  # per-channel mean is in [0, 255]
+ return contrasting_color
+ else:
+ return None
+
+
+def load_image_native(image, patch_size=16, downsample_ratio=0.5, min_pixels=65536, max_pixels=4194304, upscale=False):
+ """
+ Load and preprocess an image file, converting it to RGB mode,
+ resizing, normalizing, and optionally adding a thumbnail version.
+ """
+ if image.mode == "RGBA":
+ bg_color = get_contrasting_background(image)
+ if bg_color:
+ background = Image.new("RGB", image.size, bg_color)
+ background.paste(image, mask=image.split()[3])
+ image = background.convert("RGB")
+ else:
+ image = image.convert("RGB")
+ else:
+ image = image.convert("RGB")
+
+ if upscale:
+ image = image.resize((image.width * 2, image.height * 2), Image.BILINEAR)
+
+ transform = T.Compose(
+ [
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+ T.ToTensor(),
+ T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
+ ]
+ )
+
+ new_image = dynamic_preprocess_native_resolution(
+ image, size_factor=int(patch_size // downsample_ratio), min_pixels=min_pixels, max_pixels=max_pixels
+ )
+ pixel_values, grid_hw = preprocess_pixel_values(transform(new_image).to(torch.float32), patch_size=patch_size)
+
+ # print(f"Transfer image_size from ({image.height, image.width}) to ({new_image.height, new_image.width})")
+
+ return pixel_values, grid_hw
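A small usage sketch for the helpers above (sizes are illustrative): smart_resize keeps the aspect ratio while snapping both sides to the size factor, and the resulting grid of 16x16 patches is what preprocess_pixel_values flattens.

    from PIL import Image

    img = Image.new("RGB", (1280, 960))                        # width x height
    h, w = smart_resize(img.height, img.width, factor=32)      # 32 = patch_size / downsample_ratio
    grid_h, grid_w = h // 16, w // 16
    print(h, w, grid_h * grid_w)                               # 960 1280 4800 raw 16x16 patches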
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index 6355ac2dbf..a1f71cf02a 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -20,6 +20,7 @@
from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel
from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel
from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel
+from lightllm.models.neo_chat_moe.neo_visual import NeoVisionTransformerPretrainedModel
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.dist_utils import init_vision_distributed_env
from lightllm.utils.graceful_utils import graceful_registry
@@ -90,6 +91,8 @@ def exposed_init_model(self, kvargs):
.eval()
.bfloat16()
)
+ elif self.model_type == "neo_chat":
+ self.model = NeoVisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16()
else:
raise Exception(f"can not support {self.model_type} now")
From 0272ec39aeb62557917a01ebcc299c30299102bf Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Fri, 27 Mar 2026 07:59:37 +0000
Subject: [PATCH 02/41] neo rope
---
lightllm/models/llama/model.py | 98 ++++++++++---------
.../layer_infer/transformer_layer_infer.py | 1 +
2 files changed, 52 insertions(+), 47 deletions(-)
diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py
index 1c0277e59b..20d6cad743 100644
--- a/lightllm/models/llama/model.py
+++ b/lightllm/models/llama/model.py
@@ -73,13 +73,20 @@ def _init_custom(self):
"""
rope_scaling = self.config.get("rope_scaling", None)
if rope_scaling is None:
- scaling_type = "default"
+ self._init_to_get_rotary()
elif "rope_type" in rope_scaling:
scaling_type = rope_scaling["rope_type"]
+ self._init_rotary_by_scaling_type(scaling_type, rope_scaling)
elif "type" in rope_scaling:
scaling_type = rope_scaling["type"]
+ self._init_rotary_by_scaling_type(scaling_type, rope_scaling)
else:
raise ValueError(f"Unknown RoPE scaling format {rope_scaling}")
+ if "rope_theta_hw" in self.config:
+ self._init_to_get_hw_rotary()
+ super()._init_custom()
+
+ def _init_rotary_by_scaling_type(self, scaling_type, rope_scaling):
if scaling_type == "default" or "mrope_section" in rope_scaling:
self._init_to_get_rotary()
elif scaling_type == "yarn":
@@ -94,9 +101,6 @@ def _init_custom(self):
self._init_to_get_rotary()
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
- if "rope_theta_hw" in self.config:
- self._init_to_get_hw_rotary()
- return
def _init_to_get_rotary(self, default_base=10000):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
@@ -106,7 +110,6 @@ def _init_to_get_rotary(self, default_base=10000):
rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
base = self.config.get("rope_theta", float(default_base))
-
if "max_sequence_length" in self.config:
max_seq_len = self.config["max_sequence_length"]
else:
@@ -126,9 +129,10 @@ def _init_to_get_rotary(self, default_base=10000):
except:
pass
- inv_freq = 1.0 / (
+ full_inv_freq = 1.0 / (
base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
)
+ inv_freq = full_inv_freq[::2] # for neo
t = (
torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
/ rope_scaling_factor
@@ -139,6 +143,47 @@ def _init_to_get_rotary(self, default_base=10000):
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return
+ def _init_to_get_hw_rotary(self, default_base=10000):
+ partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2)
+ if self.config.get("rope_scaling", {}) is None:
+ rope_scaling_factor = 1.0
+ else:
+ rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+
+ base = self.config.get("rope_theta_hw", float(default_base))
+ if "max_sequence_length" in self.config:
+ max_seq_len = self.config["max_sequence_length"]
+ else:
+ max_position_embeddings = self.config.get(
+ "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384
+ )
+ max_seq_len = max_position_embeddings * rope_scaling_factor
+
+ # NTK
+ try:
+ ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
+ assert ntk_alpha >= 1
+ if ntk_alpha > 1:
+ logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+ max_seq_len *= ntk_alpha
+ base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
+ except:
+ pass
+
+ full_inv_freq = 1.0 / (
+ base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+ )
+ inv_freq = full_inv_freq[::2]
+ t = (
+ torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
+ / rope_scaling_factor
+ )
+ freqs = torch.outer(t, inv_freq)
+
+ self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+ self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda()
+ return
+
def _init_to_get_dynamic_ntk_rotary(self):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
max_position_embeddings = self.config.get("max_position_embeddings", 2048)
@@ -301,44 +346,3 @@ def _init_to_get_llama3_rotary(self, default_base=10000):
self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return
-
- def _init_to_get_hw_rotary(self, default_base=10000):
- partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_ // 2)
- if self.config.get("rope_scaling", {}) is None:
- rope_scaling_factor = 1.0
- else:
- rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
-
- base = self.config.get("rope_theta_hw", float(default_base))
- if "max_sequence_length" in self.config:
- max_seq_len = self.config["max_sequence_length"]
- else:
- max_position_embeddings = self.config.get(
- "max_position_embeddings_hw", 2048 if base <= 10000.0 + 1e-5 else 16384
- )
- max_seq_len = max_position_embeddings * rope_scaling_factor
-
- # NTK
- try:
- ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
- assert ntk_alpha >= 1
- if ntk_alpha > 1:
- logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
- max_seq_len *= ntk_alpha
- base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
- except:
- pass
-
- full_inv_freq = 1.0 / (
- base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
- )
- inv_freq = full_inv_freq[::2]
- t = (
- torch.arange(max(max_seq_len + 1024 * 128, self.max_seq_length), device="cpu", dtype=torch.float32)
- / rope_scaling_factor
- )
- freqs = torch.outer(t, inv_freq)
-
- self._hw_cos_cached = torch.cos(freqs).to(self.data_type).cuda()
- self._hw_sin_cached = torch.sin(freqs).to(self.data_type).cuda()
- return
diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
index ec181a0b8d..4517b5688a 100644
--- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -42,6 +42,7 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC
q_hw_2d = q_hw.reshape(q.shape[0], -1)
k_t_2d = k_t.reshape(k.shape[0], -1)
k_hw_2d = k_hw.reshape(k.shape[0], -1)
+
layer_weight.qk_norm_weight_(q_t_2d, k_t_2d, eps=self.eps_)
layer_weight.qk_hw_norm_weight_(q_hw_2d, k_hw_2d, eps=self.eps_)
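A standalone sketch of the rotary tables this patch builds for neo (illustrative sizes): the usual inv_freq is computed over the partial head dim, then every other frequency is kept, so the t table and the h/w table each cover half of the rotary dimensions.

    import torch

    head_dim, base, max_seq_len = 128, 10000.0, 4096

    full_inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    inv_freq = full_inv_freq[::2]                  # neo keeps every other frequency
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    cos, sin = torch.cos(freqs), torch.sin(freqs)
    print(cos.shape)                               # torch.Size([4096, 32])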
From 10f760d30d5a0ce81fd9cb34de5be5ff0eafb725 Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Mon, 23 Mar 2026 08:35:51 +0000
Subject: [PATCH 03/41] support x2i.
---
.../triton_kernel/kv_cache_offload.py | 305 ++++++++++++++++++
lightllm/server/api_cli.py | 5 +
lightllm/server/api_http.py | 17 +
lightllm/server/api_lightllm.py | 16 +
lightllm/server/api_start.py | 24 +-
lightllm/server/core/objs/__init__.py | 1 +
lightllm/server/core/objs/req.py | 7 +-
lightllm/server/core/objs/sampling_params.py | 6 +
lightllm/server/core/objs/start_args_type.py | 3 +
.../core/objs/token_chunck_hash_list.py | 11 +
lightllm/server/core/objs/x2i_params.py | 80 +++++
lightllm/server/httpserver/manager.py | 124 +++++++
.../server/router/model_infer/infer_batch.py | 3 +
.../model_infer/mode_backend/base_backend.py | 11 +
.../model_infer/mode_backend/past_kv_cache.py | 153 +++++++++
lightllm/server/x2i_server/__init__.py | 0
lightllm/server/x2i_server/manager.py | 141 ++++++++
.../server/x2i_server/past_kv_cache_client.py | 138 ++++++++
lightllm/utils/kv_cache_utils.py | 2 +-
.../kv_trans_kernel/test_kv_trans_from_gpu.py | 212 ++++++++++++
20 files changed, 1254 insertions(+), 5 deletions(-)
create mode 100644 lightllm/server/core/objs/x2i_params.py
create mode 100644 lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
create mode 100644 lightllm/server/x2i_server/__init__.py
create mode 100644 lightllm/server/x2i_server/manager.py
create mode 100644 lightllm/server/x2i_server/past_kv_cache_client.py
create mode 100644 unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py
diff --git a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py
index 0fdc43ab9f..686c0c1790 100644
--- a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py
+++ b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py
@@ -704,3 +704,308 @@ def load_cpu_kv_to_gpu(
num_stages=1,
)
return
+
+
+
+@triton.jit
+def _offload_gpu_kv_to_cpu_for_x2i(
+ token_indexes_ptr,
+ gpu_kv_cache_ptr,
+ gpu_stride0,
+ gpu_stride1,
+ gpu_stride2,
+ gpu_kv_cache_scale_ptr,
+ gpu_scale_stride0,
+ gpu_scale_stride1,
+ gpu_scale_stride2,
+ cpu_kv_cache_ptr,
+ cpu_stride0,
+ cpu_stride1,
+ cpu_stride2,
+ cpu_stride3,
+ cpu_kv_cache_scale_ptr,
+ cpu_scale_stride0,
+ cpu_scale_stride1,
+ cpu_scale_stride2,
+ cpu_scale_stride3,
+ page_indexes_ptr,
+ layer_num,
+ head_dim,
+ scale_head_dim,
+ block_num,
+ token_num,
+ cpu_k_start_head_index: tl.constexpr,
+ cpu_k_head_num: tl.constexpr,
+ gpu_k_start_head_index: tl.constexpr,
+ gpu_k_head_num: tl.constexpr,
+ cpu_v_start_head_index: tl.constexpr,
+ cpu_v_head_num: tl.constexpr,
+ gpu_v_start_head_index: tl.constexpr,
+ gpu_v_head_num: tl.constexpr,
+ BLOCK_HEAD_DIM: tl.constexpr,
+ TOKEN_BLOCK: tl.constexpr,
+ HAS_SCALE: tl.constexpr,
+):
+ block_start_index = tl.program_id(0)
+ block_split_size = tl.num_programs(axis=0)
+
+ for block_index in tl.range(block_start_index, block_num, block_split_size):
+ cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64)
+ token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
+ token_range_mask = token_range < token_num
+ token_indexes = tl.load(token_indexes_ptr + token_range, mask=token_range_mask).to(tl.int64)
+ head_dim_range = tl.arange(0, BLOCK_HEAD_DIM)
+ head_dim_mask = head_dim_range < head_dim
+ scale_head_dim_mask = head_dim_range < scale_head_dim
+
+ token_head_mask = token_range_mask[:, None] & head_dim_mask[None, :]
+ token_scale_mask = token_range_mask[:, None] & scale_head_dim_mask[None, :]
+ for layer_index in range(layer_num):
+ for k_head_index in range(gpu_k_head_num):
+ gpu_k_head_index = k_head_index + gpu_k_start_head_index
+ cpu_k_head_index = k_head_index + cpu_k_start_head_index
+
+ gpu_ptr = (
+ gpu_kv_cache_ptr
+ + layer_index.to(tl.int64) * gpu_stride0
+ + token_indexes[:, None] * gpu_stride1
+ + gpu_k_head_index.to(tl.int64) * gpu_stride2
+ + head_dim_range[None, :]
+ )
+ gpu_data = tl.load(gpu_ptr, mask=token_head_mask, other=0.0)
+ cpu_ptr = (
+ cpu_kv_cache_ptr
+ + cpu_page_index * cpu_stride0
+ + layer_index.to(tl.int64) * cpu_stride1
+ + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
+ + cpu_k_head_index * cpu_stride3
+ + head_dim_range[None, :]
+ )
+ tl.store(cpu_ptr, gpu_data, mask=token_head_mask, cache_modifier=".wt")
+
+ if HAS_SCALE:
+ gpu_scale_ptr = (
+ gpu_kv_cache_scale_ptr
+ + layer_index.to(tl.int64) * gpu_scale_stride0
+ + token_indexes[:, None] * gpu_scale_stride1
+ + gpu_k_head_index.to(tl.int64) * gpu_scale_stride2
+ + head_dim_range[None, :]
+ )
+ gpu_scale_data = tl.load(gpu_scale_ptr, mask=token_scale_mask, other=0.0)
+ cpu_scale_ptr = (
+ cpu_kv_cache_scale_ptr
+ + cpu_page_index * cpu_scale_stride0
+ + layer_index.to(tl.int64) * cpu_scale_stride1
+ + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_scale_stride2
+ + cpu_k_head_index * cpu_scale_stride3
+ + head_dim_range[None, :]
+ )
+ tl.store(cpu_scale_ptr, gpu_scale_data, mask=token_scale_mask, cache_modifier=".wt",)
+
+
+ for v_head_index in range(gpu_v_head_num):
+ gpu_v_head_index = v_head_index + gpu_v_start_head_index
+ cpu_v_head_index = v_head_index + cpu_v_start_head_index
+
+ gpu_ptr = (
+ gpu_kv_cache_ptr
+ + layer_index.to(tl.int64) * gpu_stride0
+ + token_indexes[:, None] * gpu_stride1
+ + gpu_v_head_index.to(tl.int64) * gpu_stride2
+ + head_dim_range[None, :]
+ )
+ gpu_data = tl.load(gpu_ptr, mask=token_head_mask, other=0.0)
+ cpu_ptr = (
+ cpu_kv_cache_ptr
+ + cpu_page_index * cpu_stride0
+ + layer_index.to(tl.int64) * cpu_stride1
+ + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
+ + cpu_v_head_index * cpu_stride3
+ + head_dim_range[None, :]
+ )
+ tl.store(cpu_ptr, gpu_data, mask=token_head_mask, cache_modifier=".wt")
+
+ if HAS_SCALE:
+ gpu_scale_ptr = (
+ gpu_kv_cache_scale_ptr
+ + layer_index.to(tl.int64) * gpu_scale_stride0
+ + token_indexes[:, None] * gpu_scale_stride1
+ + gpu_v_head_index.to(tl.int64) * gpu_scale_stride2
+ + head_dim_range[None, :]
+ )
+ gpu_scale_data = tl.load(gpu_scale_ptr, mask=token_scale_mask, other=0.0)
+ cpu_scale_ptr = (
+ cpu_kv_cache_scale_ptr
+ + cpu_page_index * cpu_scale_stride0
+ + layer_index.to(tl.int64) * cpu_scale_stride1
+ + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_scale_stride2
+ + cpu_v_head_index * cpu_scale_stride3
+ + head_dim_range[None, :]
+ )
+ tl.store(cpu_scale_ptr, gpu_scale_data, mask=token_scale_mask, cache_modifier=".wt",)
+
+
+
+@torch.no_grad()
+def offload_gpu_kv_to_cpu_for_x2i(
+ token_indexes: torch.Tensor,
+ gpu_kv_cache: torch.Tensor,
+ gpu_kv_cache_scale: Optional[torch.Tensor],
+ cpu_kv_cache: torch.Tensor,
+ cpu_kv_cache_scale: Optional[torch.Tensor],
+ page_indexes: torch.Tensor,
+ tp_index: int,
+ tp_world_size: int,
+ grid_num: int,
+ _cache_data={},
+):
+ """
+ Args:
+ token_indexes: (token_num, )
+ gpu_kv_cache: (layer_num, token_num, head_num, head_dim)
+ cpu_kv_cache: (all_page_num, layer_num, token_block_size, head_num, head_dim)
+ page_indexes: (page_num,)
+ """
+
+ token_block_size = cpu_kv_cache.shape[2]
+ token_num = token_indexes.shape[0]
+ assert token_num <= page_indexes.shape[0] * token_block_size
+
+ gpu_heads = gpu_kv_cache.shape[2]
+ gpu_head_dim = gpu_kv_cache.shape[3]
+ cpu_heads = cpu_kv_cache.shape[3]
+ cpu_head_dim = cpu_kv_cache.shape[4]
+
+ assert gpu_head_dim == cpu_head_dim
+ assert gpu_kv_cache.shape[0] == cpu_kv_cache.shape[1]
+
+ scale_size = (tp_world_size * gpu_heads) // cpu_heads
+
+ if (gpu_heads, cpu_heads, tp_index, tp_world_size) in _cache_data:
+ need_offload, head_info_tuple = _cache_data[(gpu_heads, cpu_heads, tp_index, tp_world_size)]
+ else:
+ if cpu_heads > 1:
+ assert (tp_world_size * gpu_heads) % cpu_heads == 0
+ assert cpu_heads % 2 == 0
+ cpu_heads_index = (
+ torch.arange(0, cpu_heads, device="cpu", dtype=torch.int32)
+ .view(cpu_heads, 1)
+ .tile((1, scale_size))
+ .view(2, tp_world_size, -1)
+ )
+ k_cpu_heads_index = cpu_heads_index[0][tp_index]
+ v_cpu_heads_index = cpu_heads_index[1][tp_index]
+
+ cpu_heads_index = torch.cat([k_cpu_heads_index, v_cpu_heads_index], dim=0).view(2, -1).numpy()
+ gpu_heads_index = torch.arange(0, gpu_heads, device="cpu", dtype=torch.int32).view(2, -1)
+
+ need_offload = tp_index % scale_size == 0
+
+ cpu_k_start_head_index = int(cpu_heads_index[0, 0])
+ cpu_k_head_num = len(cpu_heads_index[0])
+ gpu_k_start_head_index = int(gpu_heads_index[0, 0])
+ gpu_k_head_num = len(gpu_heads_index[0])
+ assert cpu_k_head_num == gpu_k_head_num
+ cpu_v_start_head_index = int(cpu_heads_index[1, 0])
+ cpu_v_head_num = len(cpu_heads_index[1])
+ gpu_v_start_head_index = int(gpu_heads_index[1, 0])
+ gpu_v_head_num = len(gpu_heads_index[1])
+ assert cpu_v_head_num == gpu_v_head_num
+
+ else:
+ assert gpu_heads == 1
+ assert cpu_heads == 1
+
+            need_offload = tp_index == 0
+ cpu_k_start_head_index = 0
+ cpu_k_head_num = 1
+ gpu_k_start_head_index = 0
+ gpu_k_head_num = 1
+ cpu_v_start_head_index = 0
+ cpu_v_head_num = 0
+ gpu_v_start_head_index = 0
+ gpu_v_head_num = 0
+
+ head_info_tuple = (
+ cpu_k_start_head_index,
+ cpu_k_head_num,
+ gpu_k_start_head_index,
+ gpu_k_head_num,
+ cpu_v_start_head_index,
+ cpu_v_head_num,
+ gpu_v_start_head_index,
+ gpu_v_head_num,
+ )
+ _cache_data[(gpu_heads, cpu_heads, tp_index, tp_world_size)] = (need_offload, head_info_tuple)
+
+ if not need_offload:
+ return
+
+ (
+ cpu_k_start_head_index,
+ cpu_k_head_num,
+ gpu_k_start_head_index,
+ gpu_k_head_num,
+ cpu_v_start_head_index,
+ cpu_v_head_num,
+ gpu_v_start_head_index,
+ gpu_v_head_num,
+ ) = head_info_tuple
+
+ assert token_block_size == triton.next_power_of_2(token_block_size)
+
+ page_num = page_indexes.shape[0]
+ grid = (grid_num, )
+ num_warps = 4
+ num_stages = 1
+ HAS_SCALE = gpu_kv_cache_scale is not None and cpu_kv_cache_scale is not None
+ if HAS_SCALE:
+ scale_head_dim = gpu_kv_cache_scale.shape[-1]
+ gpu_scale_stride = gpu_kv_cache_scale.stride()
+ cpu_scale_stride = cpu_kv_cache_scale.stride()
+ else:
+ scale_head_dim = 0
+ gpu_scale_stride = [0 for _ in range(5)]
+ cpu_scale_stride = [0 for _ in range(5)]
+
+
+ _offload_gpu_kv_to_cpu_for_x2i[grid](
+        token_indexes_ptr=token_indexes,
+        gpu_kv_cache_ptr=gpu_kv_cache,
+        gpu_stride0=gpu_kv_cache.stride(0),
+        gpu_stride1=gpu_kv_cache.stride(1),
+        gpu_stride2=gpu_kv_cache.stride(2),
+        gpu_kv_cache_scale_ptr=gpu_kv_cache_scale,
+ gpu_scale_stride0=gpu_scale_stride[0],
+ gpu_scale_stride1=gpu_scale_stride[1],
+ gpu_scale_stride2=gpu_scale_stride[2],
+ cpu_kv_cache_ptr=cpu_kv_cache,
+ cpu_stride0=cpu_kv_cache.stride(0),
+ cpu_stride1=cpu_kv_cache.stride(1),
+ cpu_stride2=cpu_kv_cache.stride(2),
+ cpu_stride3=cpu_kv_cache.stride(3),
+ cpu_kv_cache_scale_ptr=cpu_kv_cache_scale,
+ cpu_scale_stride0=cpu_scale_stride[0],
+ cpu_scale_stride1=cpu_scale_stride[1],
+ cpu_scale_stride2=cpu_scale_stride[2],
+ cpu_scale_stride3=cpu_scale_stride[3],
+ page_indexes_ptr=page_indexes,
+ layer_num=gpu_kv_cache.shape[0],
+ head_dim=gpu_head_dim,
+ scale_head_dim=scale_head_dim,
+ block_num=page_num,
+ token_num=token_num,
+ cpu_k_start_head_index=cpu_k_start_head_index,
+ cpu_k_head_num=cpu_k_head_num,
+ gpu_k_start_head_index=gpu_k_start_head_index,
+ gpu_k_head_num=gpu_k_head_num,
+ cpu_v_start_head_index=cpu_v_start_head_index,
+ cpu_v_head_num=cpu_v_head_num,
+ gpu_v_start_head_index=gpu_v_start_head_index,
+ gpu_v_head_num=gpu_v_head_num,
+ BLOCK_HEAD_DIM=triton.next_power_of_2(gpu_head_dim),
+ TOKEN_BLOCK=token_block_size,
+ HAS_SCALE=HAS_SCALE,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
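For intuition, a simplified eager-mode reference of what the x2i offload kernel does, ignoring the k/v head remapping and the quantization scales (layouts follow the docstring above; this is a sketch, not the patch's implementation):

    import torch

    def offload_reference(token_indexes, gpu_kv_cache, cpu_kv_cache, page_indexes):
        # gpu_kv_cache: [layer_num, total_tokens, head_num, head_dim]
        # cpu_kv_cache: [all_page_num, layer_num, token_block_size, head_num, head_dim]
        token_block = cpu_kv_cache.shape[2]
        token_num = token_indexes.shape[0]
        for blk, page in enumerate(page_indexes.tolist()):
            s = blk * token_block
            if s >= token_num:
                break
            e = min(s + token_block, token_num)
            toks = token_indexes[s:e].long()
            # gather the tokens of this block and copy them into the page slot
            cpu_kv_cache[page, :, : e - s] = gpu_kv_cache[:, toks].cpu()
        return cpu_kv_cache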
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index d32da8097c..06d8c57320 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -302,6 +302,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=None,
help="if the model is a multimodal model, set to not load audio part model.",
)
+ parser.add_argument(
+ "--enable_multimodal_x2i",
+ action="store_true",
+ help="Whether or not to allow to generate images (requird --enable_multimodal)."
+ )
parser.add_argument(
"--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service."
)
diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 230da5b369..4c366fb770 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -70,6 +70,7 @@ class G_Objs:
args: StartArgs = None
g_generate_func: Callable = None
g_generate_stream_func: Callable = None
+ g_generate_image_func: Callable = None
httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None
shared_token_load: TokenLoad = None
@@ -85,6 +86,10 @@ def set_args(self, args: StartArgs):
self.g_generate_func = lightllm_generate
self.g_generate_stream_func = lightllm_generate_stream
+ if args.enable_multimodal_x2i:
+ from .api_lightllm import lightllm_generate_image
+ self.g_generate_image_func = lightllm_generate_image
+
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::api_server")
if args.run_mode == "pd_master":
@@ -257,6 +262,18 @@ async def completions(request: CompletionRequest, raw_request: Request) -> Respo
resp = await completions_impl(request, raw_request)
return resp
+@app.post("/generate_image")
+async def generate_image(request: Request) -> Response:
+ if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
+ return create_error_response(
+ HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface"
+ )
+
+ try:
+ return await g_objs.g_generate_image_func(request, g_objs.httpserver_manager)
+ except Exception as e:
+ return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))
+
@app.get("/tokens")
@app.post("/tokens")
diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py
index d3592a5f54..b20670e47c 100644
--- a/lightllm/server/api_lightllm.py
+++ b/lightllm/server/api_lightllm.py
@@ -3,6 +3,7 @@
from fastapi import BackgroundTasks, Request
from fastapi.responses import Response, StreamingResponse
from lightllm.server.core.objs.sampling_params import SamplingParams
+from lightllm.server.core.objs.x2i_params import X2IParams
from .multimodal_params import MultimodalParams
from .httpserver.manager import HttpServerManager
import ujson as json
@@ -150,3 +151,18 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
background_tasks = BackgroundTasks()
return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
+
+
+async def lightllm_generate_image(request: Request, httpserver_manager: HttpServerManager) -> Response:
+    # this endpoint is currently mainly for x2v gen: the input is text, the output is image features
+ request_dict = await request.json()
+ prompt = request_dict.pop("inputs")
+ generation_params_dict = request_dict["parameters"]
+ generation_params = X2IParams()
+ generation_params.init(**generation_params_dict)
+ multimodal_params_dict = request_dict.get("multimodal_params", {})
+ multimodal_params = MultimodalParams(**multimodal_params_dict)
+
+ results = await httpserver_manager.generate_image(prompt, generation_params, multimodal_params, request=request)
+
+ return Response(content=json.dumps({"images": results}, ensure_ascii=False).encode("utf-8"))
\ No newline at end of file
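A hedged client-side sketch of calling the new endpoint (host and port are assumptions; the parameter names mirror the X2IParams fields defined later in this patch, and the response carries the "images" list returned above):

    import requests

    payload = {
        "inputs": "a cat sitting on a sunny windowsill",
        "parameters": {"width": 512, "height": 512, "steps": 30, "guidance_scale": 7.0, "seed": 42},
        "multimodal_params": {},
    }
    resp = requests.post("http://127.0.0.1:8080/generate_image", json=payload)
    print(resp.json()["images"])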
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 364f9ca281..dc347e4929 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -118,6 +118,8 @@ def normal_or_p_d_start(args):
if not args.disable_shm_warning:
check_recommended_shm_size(args)
+ if args.enable_multimodal_x2i:
+ args.multi_modal_x2i_cache_shm_id = uuid.uuid1().int % 123456789
assert args.zmq_mode in ["tcp://", "ipc:///tmp/"]
# 确保单机上多实列不冲突
@@ -248,7 +250,7 @@ def normal_or_p_d_start(args):
node_world_size = args.tp // args.nnodes
can_use_ports = alloc_can_use_network_port(
- num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports
+        num=12 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_nccl_ports=already_uesd_ports
)
logger.info(f"alloced ports: {can_use_ports}")
(
@@ -262,8 +264,10 @@ def normal_or_p_d_start(args):
metric_port,
multi_level_kv_cache_port,
pd_decode_rpyc_port,
- ) = can_use_ports[0:10]
- can_use_ports = can_use_ports[10:]
+ x2i_port,
+ http_server_port_for_x2i,
+ ) = can_use_ports[0:12]
+ can_use_ports = can_use_ports[12:]
visual_model_tp_ports = []
visual_nccl_ports = []
@@ -288,6 +292,8 @@ def normal_or_p_d_start(args):
args.metric_port = metric_port
args.multi_level_kv_cache_port = multi_level_kv_cache_port
args.visual_nccl_ports = visual_nccl_ports
+ args.x2i_port = x2i_port
+ args.http_server_port_for_x2i = http_server_port_for_x2i
# 申请在 p d 分离模式下,会用的端口
args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size]
# p d 分离模式下用于标识节点的id
@@ -346,6 +352,18 @@ def normal_or_p_d_start(args):
],
)
+ if args.enable_multimodal_x2i:
+ from .x2i_server.manager import start_x2i_process
+
+ process_manager.start_submodule_processes(
+ start_funcs=[
+ start_x2i_process,
+ ],
+ start_args=[
+ (args,),
+ ],
+ )
+
if args.enable_cpu_cache:
from .multi_level_kv_cache.manager import start_multi_level_kv_cache_manager
diff --git a/lightllm/server/core/objs/__init__.py b/lightllm/server/core/objs/__init__.py
index ec2438de9d..06ee53be33 100644
--- a/lightllm/server/core/objs/__init__.py
+++ b/lightllm/server/core/objs/__init__.py
@@ -1,4 +1,5 @@
from .sampling_params import SamplingParams
+from .x2i_params import X2IParams
from .req import Req, FinishStatus
from .shm_req_manager import ShmReqManager
from .start_args_type import StartArgs
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index 8905248bf8..3ca4b55c6a 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -6,7 +6,7 @@
from .sampling_params import SamplingParams
from .out_token_circlequeue import CircularQueue
from .shm_array import ShmArray
-from .token_chunck_hash_list import TokenHashList, CpuCachePageList
+from .token_chunck_hash_list import TokenHashList, CpuCachePageList, PastKVCachePageList
from lightllm.server.req_id_generator import convert_sub_id_to_group_id
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.utils.envs_utils import get_env_start_args
@@ -122,6 +122,8 @@ class Req(ctypes.Structure):
("cpu_cache_match_page_indexes", CpuCachePageList),
# 分块hash的块大小
("cpu_cache_token_page_size", ctypes.c_int),
+        # for image-generation requests: records the request's kv cache page info for use during generation.
+ ("past_kv_cache_page_indexes", PastKVCachePageList),
]
def get_str(self):
@@ -185,6 +187,9 @@ def init(
self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size
if get_env_start_args().enable_cpu_cache:
self._fill_input_token_hash()
+
+ if sample_param.img_gen_prefill:
+ self.past_kv_cache_page_indexes = PastKVCachePageList(self.input_len)
return
def post_init(self):
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index 49b21c38fc..3cf5cb0887 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -322,6 +322,7 @@ class SamplingParams(ctypes.Structure):
("print_eos_token", ctypes.c_bool), # eos_id will be always ignored except the value is set to True
("disable_prompt_cache", ctypes.c_bool), # whether to disable prompt cache
("seed", ctypes.c_int64), # random seed
+ ("img_gen_prefill", ctypes.c_bool), # whether to prefill for image generation, need return past key values back
]
_do_sample: bool = False
@@ -354,6 +355,8 @@ def init(self, tokenizer, **kwargs):
self.skip_special_tokens = kwargs.get("skip_special_tokens", SKIP_SPECIAL_TOKENS)
self.disable_prompt_cache = kwargs.get("disable_prompt_cache", False)
+ self.img_gen_prefill = kwargs.get("img_gen_prefill", False)
+
self.add_special_tokens = kwargs.get("add_special_tokens", True)
self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True)
self.print_eos_token = kwargs.get("print_eos_token", False)
@@ -404,6 +407,9 @@ def init(self, tokenizer, **kwargs):
self.temperature = 1.0
self.top_k = 1
+ if self.img_gen_prefill:
+ self.max_new_tokens = 1
+
self.verify()
@classmethod
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 37c022f3a3..b741847b4c 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -160,3 +160,6 @@ class StartArgs:
metric_port: int = field(default=None)
multinode_httpmanager_port: int = field(default=12345)
multi_level_kv_cache_port: int = field(default=None)
+ x2i_port: int = field(default=None)
+ http_server_port_for_x2i: int = field(default=None)
+ enable_multimodal_x2i: bool = field(default=False)
diff --git a/lightllm/server/core/objs/token_chunck_hash_list.py b/lightllm/server/core/objs/token_chunck_hash_list.py
index a79ff48a85..479c6c17ff 100644
--- a/lightllm/server/core/objs/token_chunck_hash_list.py
+++ b/lightllm/server/core/objs/token_chunck_hash_list.py
@@ -88,3 +88,14 @@ def clear(self):
def get_all(self):
return list(self.items[0 : self.size])
+
+
+class PastKVCachePageList(CpuCachePageList):
+ _pack_ = 4
+    _fields_ = CpuCachePageList._fields_ + [
+ ("token_len", ctypes.c_int), # 对应的token数量
+ ]
+
+ def __init__(self, token_len: int = 0):
+ super().__init__()
+ self.token_len = token_len
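Minimal usage sketch for the new list type (fill() and get_all() are assumed to behave as they are used elsewhere in this patch):

    pages = PastKVCachePageList(token_len=96)
    pages.fill([3, 7, 11])                  # cpu cache page indexes backing those 96 tokens
    print(pages.token_len, pages.get_all())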
diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py
new file mode 100644
index 0000000000..66fcbe9afb
--- /dev/null
+++ b/lightllm/server/core/objs/x2i_params.py
@@ -0,0 +1,80 @@
+import ctypes
+from dataclasses import dataclass
+from typing import Dict, List
+from enum import IntEnum
+from .token_chunck_hash_list import PastKVCachePageList
+
+class CfgNormType(IntEnum):
+ NONE = 0
+ CFG_ZERO_STAR = 1
+ GLOBAL = 2
+
+
+class X2IParams(ctypes.Structure):
+ _pack_ = 4
+ _fields_ = [
+ ("width", ctypes.c_int),
+ ("height", ctypes.c_int),
+ ("steps", ctypes.c_int),
+ ("guidance_scale", ctypes.c_float),
+ ("image_guidance_scale", ctypes.c_float),
+ ("seed", ctypes.c_int),
+ ("num_images", ctypes.c_int),
+ ("cfg_norm", ctypes.c_int),
+ ("past_kvcache", PastKVCachePageList),
+ ("past_kvcache_text", PastKVCachePageList),
+ ("past_kvcache_img", PastKVCachePageList),
+ ("total_prompt_tokens", ctypes.c_int),
+ ("request_id", ctypes.c_int64),
+ ]
+
+ _width: int = 512
+ _height: int = 512
+ _steps: int = 30
+ _guidance_scale: float = 7.0
+ _image_guidance_scale: float = 7.0
+ _seed: int = 42
+ _num_images: int = 1
+ _cfg_norm: CfgNormType = CfgNormType.NONE
+
+ def init(self, **kwargs):
+ def _get(key, default):
+ v = kwargs.get(key)
+ return v if v is not None else default
+ self.width = _get("width", X2IParams._width)
+ self.height = _get("height", X2IParams._height)
+ self.steps = _get("steps", X2IParams._steps)
+ self.guidance_scale = _get("guidance_scale", X2IParams._guidance_scale)
+ self.image_guidance_scale = _get("image_guidance_scale", X2IParams._image_guidance_scale)
+ self.seed = _get("seed", X2IParams._seed)
+ self.num_images = _get("num_images", X2IParams._num_images)
+ self.cfg_norm = _get("cfg_norm", X2IParams._cfg_norm)
+ self.past_kvcache = PastKVCachePageList()
+ self.past_kvcache_text = PastKVCachePageList()
+ self.past_kvcache_img = PastKVCachePageList()
+ self.total_prompt_tokens = 0
+ self.request_id = 0
+
+ def update(self, past_kv: PastKVCachePageList, meta: Dict):
+ past_kv.token_len = meta.get("prompt_tokens")
+ past_kv.fill(meta.get("kv_cache_pages"))
+ self.total_prompt_tokens += past_kv.token_len
+
+ def update_t2i(self, meta, meta_uncond):
+ self.update(self.past_kvcache, meta)
+ self.update(self.past_kvcache_text, meta_uncond)
+
+ def update_it2i(self, meta, meta_text_uncond, meta_img_uncond):
+ self.update(self.past_kvcache, meta)
+ self.update(self.past_kvcache_text, meta_text_uncond)
+ self.update(self.past_kvcache_img, meta_img_uncond)
+
+
+@dataclass
+class X2IResponse:
+ request_id: int
+ images: List[bytes]
+
+@dataclass
+class X2ICacheRelease:
+ request_id: int
\ No newline at end of file
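A small usage sketch of the new params object (the metadata keys mirror what generate_image attaches in the httpserver manager below; the page lists are made up):

    params = X2IParams()
    params.init(width=768, height=768, steps=20)
    meta_cond = {"prompt_tokens": 42, "kv_cache_pages": [0, 1, 2]}
    meta_uncond = {"prompt_tokens": 11, "kv_cache_pages": [3]}
    params.update_t2i(meta_cond, meta_uncond)
    print(params.total_prompt_tokens)       # 53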
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index e28e4c93ad..1c943e4937 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -10,6 +10,7 @@
import hashlib
import datetime
import pickle
+import re
from frozendict import frozendict
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -24,6 +25,7 @@
from .async_queue import AsyncQueue
from lightllm.server.core.objs import Req, FinishStatus, StartArgs
from lightllm.server.core.objs import SamplingParams
+from lightllm.server.core.objs.x2i_params import X2IParams, X2ICacheRelease, X2IResponse
from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
from lightllm.server.core.objs.io_objs import GroupReqObjs
from lightllm.server.core.objs.shm_req_manager import ShmReqManager
@@ -95,6 +97,15 @@ def __init__(
self.send_to_multi_level_kv_cache = context.socket(zmq.PUSH)
self.send_to_multi_level_kv_cache.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}")
+ if args.enable_multimodal_x2i:
+ from lightllm.server.x2i_server.past_kv_cache_client import PastKVCacheClient
+ self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=True, init_shm_data=False)
+ self.send_to_x2i = context.socket(zmq.PUSH)
+ self.send_to_x2i.connect(f"{args.zmq_mode}127.0.0.1:{args.x2i_port}")
+ self.recv_from_x2i = context.socket(zmq.PULL)
+ self.recv_from_x2i.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port_for_x2i}")
+ self.req_id_to_x2i_reqs: Dict[int, X2IReqStatus] = {}
+
self.shm_req_manager = ShmReqManager()
# recv from detokenization
@@ -354,6 +365,15 @@ async def generate(
self.tokenizer,
chunked_prefill_size=self.args.chunked_prefill_size,
)
+ if sampling_params.img_gen_prefill:
+ # allocate pages, may block if cache is full, but it won't cause deadlock
+ # because the prefill process is designed to be sequential and the pages
+ # will be released after prefill.
+ kv_pages = self.past_kv_cache_client.allocate_pages(
+ req_obj.request_id, req_obj.input_len)
+
+ req_obj.past_kv_cache_page_indexes.fill(kv_pages)
+
req_objs.append(req_obj)
logger.debug(
@@ -404,10 +424,82 @@ async def generate(
# 进行回收。
if group_request_id not in self.req_id_to_out_inf:
await self._release_multimodal_resources(multimodal_params)
+
+ if sampling_params.img_gen_prefill:
+                # requests that pre-allocated kv cache pages must release them explicitly on exceptions
+ self.past_kv_cache_client.free_pages_by_req_id(group_request_id)
+
await self.abort(group_request_id)
+
raise e
return
+
+ async def generate_image(self, prompt: str, generation_params: X2IParams, multimodal_params: MultimodalParams, request: Request):
+ generate_req_ids = []
+ async def generation_wrapper(prompt, sample, multimodal, request):
+ async for sub_req_id, _, metadata, finish_status in self.generate(
+ prompt, sample, multimodal, request
+ ):
+ kv_cache_pages = self.past_kv_cache_client.get_pages_by_req_id(sub_req_id)
+ if kv_cache_pages is None:
+ raise Exception(f"kv_cache_pages is None for sub_req_id {sub_req_id}")
+ metadata["kv_cache_pages"] = kv_cache_pages
+ metadata["request_id"] = sub_req_id
+ metadata["finish_status"] = finish_status
+ generate_req_ids.append(sub_req_id)
+ return metadata
+
+ try:
+            # 1. build 3 or 2 prefill prompts based on the multimodal_params
+ sample_params = SamplingParams()
+ sample_params.init(self.tokenizer, **{"img_gen_prefill": True})
+ img_len = len(multimodal_params.images)
+
+ if img_len > 0:
+ # call it2i
+ prompt_condition = f"Please generate an image based on the following instruction: {prompt}"
+                prompt_text_uncondition = "Please generate an image based on the following instruction: " + "\n" * img_len
+ prompt_img_uncondition = "Please generate an image based on the following instruction: " + re.sub(r"\n?", "", prompt)
+ (con_gen, text_uncon_gen, img_uncon_gen) = await asyncio.gather(*[
+ generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
+ generation_wrapper(prompt_text_uncondition, sample_params, multimodal_params, request),
+ generation_wrapper(prompt_img_uncondition, sample_params, MultimodalParams(), request)])
+ generation_params.update_it2i(con_gen, text_uncon_gen, img_uncon_gen)
+ else:
+ # call t2i
+ prompt_condition = f"Please generate an image based on the following caption: {prompt}"
+ prompt_uncondition = f"Please generate an image based on the following caption: "
+ (con_gen, uncon_gen) = await asyncio.gather(*[
+ generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
+ generation_wrapper(prompt_uncondition, sample_params, multimodal_params, request)])
+ generation_params.update_t2i(con_gen, uncon_gen)
+            # use the first request id as the generated-image request id
+ x2i_req_id = generate_req_ids[0]
+ generation_params.request_id = x2i_req_id
+
+ req_status = X2IReqStatus(generation_params, generate_req_ids)
+ self.req_id_to_x2i_reqs[generation_params.request_id] = req_status
+
+ # send generation_params to generation server for image generation
+ await self.send_to_x2i.send_pyobj(generation_params, protocol=pickle.HIGHEST_PROTOCOL)
+
+ await req_status.event.wait()
+
+ assert req_status.response is not None
+
+ self.req_id_to_x2i_reqs.pop(x2i_req_id, None)
+
+ return req_status.response.images
+
+ except Exception as e:
+ logger.error(str(e))
+ pass
+
+ finally:
+ for req_id in generate_req_ids:
+ self.past_kv_cache_client.free_pages_by_req_id(req_id)
+
def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]:
image_tokens = 0
audio_tokens = 0
@@ -739,6 +831,28 @@ async def recycle_resource_loop(self):
)
return
+ async def loop_for_x2i(self):
+
+ while True:
+ try:
+ recv_obj = await asyncio.wait_for(self.recv_from_x2i.recv_pyobj(), timeout=0.05)
+
+ if isinstance(recv_obj, X2ICacheRelease):
+ status = self.req_id_to_x2i_reqs[recv_obj.request_id]
+ for req_id in status.req_ids:
+ self.past_kv_cache_client.free_pages_by_req_id(req_id)
+
+ elif isinstance(recv_obj, X2IResponse):
+ status = self.req_id_to_x2i_reqs[recv_obj.request_id]
+ status.response = recv_obj
+ status.event.set()
+
+ except asyncio.TimeoutError:
+ pass
+ except Exception as e:
+ logger.error(e)
+
+
async def handle_loop(self):
self.recycle_event = asyncio.Event()
asyncio.create_task(self.recycle_resource_loop())
@@ -753,6 +867,9 @@ async def handle_loop(self):
asyncio.create_task(pd_handle_loop(self))
+ if self.args.enable_multimodal_x2i:
+ asyncio.create_task(self.loop_for_x2i())
+
while True:
try:
await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05)
@@ -839,3 +956,10 @@ def can_release(self):
if not req.can_release():
return False
return True
+
+class X2IReqStatus:
+ def __init__(self, req_param: X2IParams, req_ids: List[int]):
+        self.req = req_param
+ self.req_ids = req_ids
+ self.event = asyncio.Event()
+ self.response: X2IResponse = None
\ No newline at end of file
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 0a83b101be..0088c60d65 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -357,6 +357,9 @@ def __init__(
# 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态
self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED
+ # img gen req need copy kv to cpu
+ self.past_kv_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED
+
# mtp_step 用来记录一个请求 draft模型每步需要生成的token数量
# 正常模式下,这个值为0,在 mtp 模式下,这个值为 draft 模型每步需要生成的token数量
self.mtp_step: int = get_env_start_args().mtp_step
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 8b085c45ed..12e418824b 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -47,6 +47,7 @@
from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token
from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet
from .multi_level_kv_cache import MultiLevelKvCacheModule
+from .past_kv_cache import PastKVCacheModule
class ModeBackend:
@@ -147,6 +148,9 @@ def init_model(self, kvargs):
if self.args.enable_multimodal:
g_infer_context.init_cpu_embed_cache_client()
+ if self.args.enable_multimodal_x2i:
+ self.past_kv_cache_module = PastKVCacheModule(self)
+
model_cfg, _ = PretrainedConfig.get_config_dict(self.weight_dir)
model_kvargs = {
@@ -551,6 +555,9 @@ def _get_classed_reqs(
if self.args.enable_cpu_cache and len(g_infer_context.infer_req_ids) > 0:
self.multi_level_cache_module.update_cpu_cache_task_states()
+ if self.args.enable_multimodal_x2i and len(g_infer_context.infer_req_ids) > 0:
+ self.past_kv_cache_module.update_past_kv_cache_task_states()
+
if req_ids is None:
req_ids = g_infer_context.infer_req_ids
@@ -648,6 +655,10 @@ def _get_classed_reqs(
else:
true_finished_reqs = finished_reqs
+ if self.args.enable_multimodal_x2i:
+ true_finished_reqs = self.past_kv_cache_module.offload_finished_reqs_to_past_kv_cache(
+ finished_reqs=true_finished_reqs)
+
g_infer_context.filter_reqs(finished_reqs=true_finished_reqs)
g_infer_context.pause_reqs(wait_pause_reqs, is_master_in_dp=self.is_master_in_dp)
diff --git a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
new file mode 100644
index 0000000000..de80da5d01
--- /dev/null
+++ b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
@@ -0,0 +1,153 @@
+import torch
+import torch.distributed as dist
+from dataclasses import dataclass
+from typing import Optional, List, Deque
+from collections import deque
+from functools import lru_cache
+from lightllm.server.x2i_server.past_kv_cache_client import PastKVCacheClient
+from lightllm.server.router.model_infer.infer_batch import InferReq
+from lightllm.server.router.model_infer.infer_batch import g_infer_context
+from lightllm.utils.dist_utils import create_new_group_for_current_dp
+from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu_for_x2i
+from lightllm.server.core.objs.token_chunck_hash_list import LIGHTLLM_TOKEN_HASH_LIST_SIZE
+
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class TransTask:
+ req_obj: InferReq
+ sync_event: torch.cuda.Event
+
+
+class PastKVCacheModule(object):
+ def __init__(self, backend):
+ from .base_backend import ModeBackend
+ self.backend: ModeBackend = backend
+ self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=False)
+ self.page_index_buffer = torch.empty((LIGHTLLM_TOKEN_HASH_LIST_SIZE * 2,), dtype=torch.int32, device="cuda")
+ self.past_kv_cache_task: Deque[TransTask] = deque()
+ self.sync_task_status_group = create_new_group_for_current_dp("gloo")
+
+ @lru_cache()
+ def need_sync_compute_stream(self) -> bool:
+ """
+        With the fa3 backend, offloading or loading the kv cache must wait for the compute stream to finish, otherwise it can crash intermittently.
+ """
+
+ model = self.backend.model
+ att_backends = [
+ model.prefill_att_backend,
+ model.decode_att_backend,
+ model.prefill_att_backend1,
+ model.decode_att_backend1,
+ ]
+ for att_backend in att_backends:
+ if att_backend is not None and "fa3" in att_backend.__class__.__name__.lower():
+ logger.info("PastKVCacheModule: need sync compute stream for fa3 backend.")
+ return True
+ logger.info("PastKVCacheModule: no need sync compute stream.")
+ return False
+
+
+ def offload_finished_reqs_to_past_kv_cache(self, finished_reqs: List[InferReq]) -> List[InferReq]:
+ """
+        Offload finished requests to the past kv cache and return the requests that are truly
+        finished and can be freed from the infer batch.
+ """
+ true_finished_reqs = []
+ for req in finished_reqs:
+ # filter out non-img-gen reqs
+ if not req.shm_req.sample_params.img_gen_prefill:
+ true_finished_reqs.append(req)
+ continue
+
+ if req.past_kv_cache_task_status.is_finished():
+ true_finished_reqs.append(req)
+ continue
+
+ if req.past_kv_cache_task_status.is_running():
+ continue
+
+ assert req.past_kv_cache_task_status.is_not_started()
+
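+            # not offloaded yet: start an async gpu -> cpu kv copy and keep the request in the batch until it completes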
+ if self.need_sync_compute_stream():
+ g_infer_context.get_overlap_stream().synchronize()
+
+ trans_task = self._start_kv_cache_offload(req=req)
+ assert trans_task is not None
+ self.past_kv_cache_task.append(trans_task)
+
+
+ return true_finished_reqs
+
+ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]:
+
+ with torch.cuda.stream(g_infer_context.get_cpu_kv_cache_stream()):
+            page_indexes = torch.tensor(
+                req.shm_req.past_kv_cache_page_indexes.get_all(), dtype=torch.int32, device="cpu", pin_memory=True
+            )
+ num_tokens = req.shm_req.input_len
+
+ assert req.cur_kv_len >= num_tokens
+ assert num_tokens <= len(page_indexes) * self.past_kv_cache_client.token_page_size
+
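+            # copy the pinned page indexes into the preallocated cuda buffer used by the offload kernel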
+ cuda_page_indexes = self.page_index_buffer[:len(page_indexes)]
+ cuda_page_indexes.copy_(page_indexes)
+
+ token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0: num_tokens]
+ mem_manager = self.backend.model.mem_manager
+
+
+ if hasattr(mem_manager, "scale_buffer") and mem_manager.scale_buffer is not None:
+ cpu_cache_meta = self.past_kv_cache_client.kv_cache_tensor_meta
+ cpu_kv_cache = self.past_kv_cache_client.cpu_kv_cache_tensor[:, :, :, :, 0:cpu_cache_meta.head_dim]
+ cpu_kv_cache_scale = self.past_kv_cache_client.cpu_kv_cache_tensor[
+ :, :, :, :, cpu_cache_meta.head_dim
+ ].view(mem_manager.scale_buffer.dtype)
+ gpu_kv_cache_scale = mem_manager.scale_buffer
+ else:
+ cpu_kv_cache = self.past_kv_cache_client.cpu_kv_cache_tensor
+ cpu_kv_cache_scale = None
+ gpu_kv_cache_scale = None
+
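+            # launch the triton kernel that copies this request's kv entries from the gpu kv buffer into the shared cpu cache pages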
+ grid_num = 16
+ offload_gpu_kv_to_cpu_for_x2i(
+ token_indexes=token_indexes,
+ gpu_kv_cache=mem_manager.kv_buffer,
+ gpu_kv_cache_scale=gpu_kv_cache_scale,
+ cpu_kv_cache=cpu_kv_cache,
+ cpu_kv_cache_scale=cpu_kv_cache_scale,
+ page_indexes=cuda_page_indexes,
+ tp_index=self.backend.rank_in_dp,
+ tp_world_size=self.backend.dp_world_size,
+ grid_num=grid_num,
+ )
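+            # record an event on the offload stream so completion can be polled later without blocking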
+ sync_event = torch.cuda.Event()
+ sync_event.record()
+ req.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING
+ return TransTask(
+ req_obj=req,
+ sync_event=sync_event,
+ )
+
+ def update_past_kv_cache_task_states(self):
+ trans_ok_tasks = []
+ while len(self.past_kv_cache_task) > 0:
+ task: TransTask = self.past_kv_cache_task.popleft()
+ if task.sync_event.query():
+ trans_ok_tasks.append(task)
+ else:
+ self.past_kv_cache_task.appendleft(task)
+ break
+
+ if len(trans_ok_tasks) == 0:
+ return
+
+ ok_tasks_num = torch.tensor(len(trans_ok_tasks))
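+        # reduce with MIN so every rank in the dp group agrees on how many tasks can be marked finished this round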
+ dist.all_reduce(ok_tasks_num, op=dist.ReduceOp.MIN, group=self.sync_task_status_group)
+
+ if ok_tasks_num.item() > 0:
+ finished, unfinished = trans_ok_tasks[:ok_tasks_num.item()], trans_ok_tasks[ok_tasks_num.item():]
+ self.past_kv_cache_task.extendleft(reversed(unfinished))
+ for task in finished:
+ task.req_obj.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED
diff --git a/lightllm/server/x2i_server/__init__.py b/lightllm/server/x2i_server/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
new file mode 100644
index 0000000000..3dc82be897
--- /dev/null
+++ b/lightllm/server/x2i_server/manager.py
@@ -0,0 +1,141 @@
+import zmq
+import zmq.asyncio
+import asyncio
+import uvloop
+import inspect
+import setproctitle
+import pickle
+import torch
+from typing import List
+from lightllm.server.core.objs import StartArgs
+
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.graceful_utils import graceful_registry
+from lightllm.utils.process_check import start_parent_check_thread
+from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease
+from .past_kv_cache_client import PastKVCacheClient
+
+logger = init_logger(__name__)
+
+'''
+Manage the image generation service:
+1. start the x2v pipelines
+2. receive generation requests from the http_server
+3. call the llm to obtain the past key values
+4. call x2v to generate images, passing the key values to it
+5. return the generated images
+'''
+
+class X2IManager:
+ def __init__(
+ self,
+ args: StartArgs,
+ ):
+ context = zmq.Context(2)
+ self.args = args
+
+ self.zmq_recv_socket = context.socket(zmq.PULL)
+ self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.x2i_port}")
+
+ self.send_to_httpserver = context.socket(zmq.PUSH)
+ self.send_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port_for_x2i}")
+
+ self.waiting_reqs: List[X2IParams] = []
+
+ from lightllm.utils.dist_utils import set_current_device_id
+
+ set_current_device_id(torch.cuda.current_device())
+
+ self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True)
+
+ async def wait_to_model_ready(self):
+ # from lightx2v import LightX2VPipeline
+ # self.gen_pipe = LightX2VPipeline(
+ # model_path = self.args.model_dir,
+ # model_cls = self.args.model_name,
+ # task="t2i"
+ # )
+ # self.gen_pipe.create_generator(
+ # config_json = self.args.x2v_gen_model_config,
+ # )
+
+ pass
+
+ async def loop_for_fwd(self):
+ while True:
+ try:
+ if len(self.waiting_reqs) == 0:
+ await asyncio.sleep(0.01)
+ continue
+
+ x2i_param = self.waiting_reqs.pop(0)
+
+ past_kv_cache = self.past_kv_cache_client.get_kv_cache_for_x2i(
+ x2i_param.past_kvcache.get_all(), x2i_param.past_kvcache.token_len
+ )
+
+ past_kv_cache_text = self.past_kv_cache_client.get_kv_cache_for_x2i(
+ x2i_param.past_kvcache_text.get_all(), x2i_param.past_kvcache_text.token_len
+ )
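+                # a request with no image-conditioned kv cache is a pure text-to-image request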
+ is_t2i = x2i_param.past_kvcache_img.is_empty()
+
+ logger.info(f"past kv cache shape: {past_kv_cache.shape}, past_kv_cache_text shape: {past_kv_cache_text.shape}")
+
+ past_kv_cache_img = None
+                if not is_t2i:  # it2i
+ past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i(
+ x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len
+ )
+
+ # release
+ self.send_to_httpserver.send_pyobj(
+ X2ICacheRelease(request_id=x2i_param.request_id),
+ protocol=pickle.HIGHEST_PROTOCOL)
+
+ # call generate images
+ self.send_to_httpserver.send_pyobj(X2IResponse(
+ request_id=x2i_param.request_id,
+ images=[]),
+ protocol=pickle.HIGHEST_PROTOCOL)
+
+ except Exception as e:
+ logger.error(e)
+
+
+ async def loop_for_netio_req(self):
+ while True:
+ try:
+ recv_req: X2IParams = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK)
+ self.waiting_reqs.append(recv_req)
+
+ except zmq.ZMQError:
+ await asyncio.sleep(0.1)
+
+ await asyncio.sleep(0.01)
+
+def start_x2i_process(args, pipe_writer):
+    # register handlers for graceful shutdown
+ graceful_registry(inspect.currentframe().f_code.co_name)
+ setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::x2i_server")
+ start_parent_check_thread()
+ try:
+ x2iserver = X2IManager(args=args,)
+ asyncio.run(x2iserver.wait_to_model_ready())
+ except Exception as e:
+ logger.exception(str(e))
+ x2iserver.clean_up()
+ raise e
+
+ pipe_writer.send("init ok")
+
+ def handle_exception(loop, context):
+ logger.exception(f"X2IServer Caught exception: {str(context)}")
+
+ loop = asyncio.new_event_loop()
+ loop.set_exception_handler(handle_exception)
+ asyncio.set_event_loop(loop)
+ loop.create_task(x2iserver.loop_for_fwd())
+ loop.run_until_complete(x2iserver.loop_for_netio_req())
+ return
diff --git a/lightllm/server/x2i_server/past_kv_cache_client.py b/lightllm/server/x2i_server/past_kv_cache_client.py
new file mode 100644
index 0000000000..4a95635d1e
--- /dev/null
+++ b/lightllm/server/x2i_server/past_kv_cache_client.py
@@ -0,0 +1,138 @@
+import ctypes
+import torch
+import numpy as np
+from threading import Lock, Condition
+from dataclasses import dataclass
+from lightllm.utils.envs_utils import get_env_start_args
+from typing import List, Optional, Tuple
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.kv_cache_utils import (
+ calcu_cpu_cache_meta,
+ create_shm_kv_cache_ptr,
+ attach_shm_kv_cache_ptr,
+ register_shm_ptr_to_pin,
+)
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class PastKVCacheItem:
+ req_id: int
+ token_len: int
+ page_indexes: List[int]
+
+
+class PastKVCacheClient(object):
+ """
+ This class is responsible for passing kv cache between generation server and model server,
+ and manage the shared memory for kv cache.
+ """
+
+ def __init__(self, only_create_meta_data: bool, init_shm_data: bool):
+ self.args = get_env_start_args()
+        # TODO: this should be calculated from the start settings.
+ self.kv_cache_tensor_meta = calcu_cpu_cache_meta()
+ self.page_num: int = self.kv_cache_tensor_meta.page_num
+ self.token_page_size: int = self.kv_cache_tensor_meta.token_page_size
+ self.allocated_pages_dict: dict[int, PastKVCacheItem] = {}
+ self.free_pages: List[int] = list(range(self.page_num))
+ self.lock = Lock()
+ self.cond = Condition(self.lock)
+
+ if not only_create_meta_data:
+ if init_shm_data:
+ self._create_shm_cpu_kv_cache()
+ self.attach_shm_handle = None
+ else:
+ self.attach_shm_handle = self._attach_shm_cpu_kv_cache()
+ self.attach_shm_handle.wait()
+ return
+
+ def allocate_pages(self, req_id: int, need_tokens: int) -> List[int]:
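+        # whole pages are allocated: ceil(need_tokens / token_page_size)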
+ need_pages = (need_tokens + self.token_page_size - 1) // self.token_page_size
+ if need_pages > self.page_num:
+ logger.error(
+ f"Request {req_id} need {need_tokens} tokens, which requires {need_pages} pages, "
+ f"exceeds the total page number {self.page_num}"
+ )
+ raise ValueError(f"error allocate pages for request {req_id} with {need_tokens} tokens")
+
+ with self.cond:
+ while len(self.free_pages) < need_pages:
+ self.cond.wait()
+
+ page_indexes, self.free_pages = self.free_pages[:need_pages], self.free_pages[need_pages:]
+ self.allocated_pages_dict[req_id] = PastKVCacheItem(
+ req_id=req_id, token_len=need_tokens, page_indexes=page_indexes)
+
+ return page_indexes
+
+ def free_pages_by_req_id(self, req_id: int):
+ with self.cond:
+ item = self.allocated_pages_dict.pop(req_id, None)
+ if item is not None:
+ self.free_pages.extend(item.page_indexes)
+ self.cond.notify_all()
+
+ def get_pages_by_req_id(self, req_id: int) -> Optional[List[int]]:
+ with self.lock:
+ item = self.allocated_pages_dict.get(req_id, None)
+ return item.page_indexes if item is not None else None
+
+ def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optional[torch.Tensor]:
+ if page_indexes is None:
+ return None
+ assert token_num <= len(page_indexes) * self.token_page_size and \
+ token_num > (len(page_indexes) - 1) * self.token_page_size
+ (P, L, S, H, D) = self.cpu_kv_cache_tensor[page_indexes].shape
+ # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (2, L, H // 2, P, S, D)
+ # -> (2, L, H // 2, P * S, D) -> ( L, 2, H // 2, P * S, D)
+ kv = self.cpu_kv_cache_tensor[page_indexes] \
+ .view(P, L, S, 2, H // 2, D) \
+ .permute(3, 1, 4, 0, 2, 5).contiguous() \
+ .view(2, L, H // 2, P * S, D) \
+ .permute(1, 0, 2, 3, 4)
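+        # final layout: (L, 2, H // 2, token_num, D) after trimming the padded tail of the last page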
+ return kv[:, :, :, :token_num, :].contiguous()
+
+ def _create_shm_cpu_kv_cache(self):
+ shm_ptr = create_shm_kv_cache_ptr(
+ key=self.args.multi_modal_x2i_cache_shm_id, size=self.kv_cache_tensor_meta.calcu_size()
+ )
+ numpy_array = np.frombuffer(
+ memoryview((ctypes.c_uint8 * self.kv_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), dtype=np.uint8
+ )
+        # convert the NumPy array into a PyTorch tensor view of the shared memory
+ shape = (
+ self.kv_cache_tensor_meta.page_num,
+ self.kv_cache_tensor_meta.layer_num,
+ self.kv_cache_tensor_meta.token_page_size,
+ self.kv_cache_tensor_meta.num_heads,
+ self.kv_cache_tensor_meta.get_merged_head_dim(),
+ )
+ self.cpu_kv_cache_tensor = (
+ torch.from_numpy(numpy_array).view(dtype=self.kv_cache_tensor_meta.data_type).view(shape)
+ )
+ return
+
+ def _attach_shm_cpu_kv_cache(self):
+ shm_ptr = attach_shm_kv_cache_ptr(
+ key=self.args.multi_modal_x2i_cache_shm_id, size=self.kv_cache_tensor_meta.calcu_size()
+ )
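+        # pin the shared-memory region for cuda (register_shm_ptr_to_pin) so async gpu<->cpu copies can use it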
+ handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.kv_cache_tensor_meta.calcu_size())
+ numpy_array = np.frombuffer(
+ memoryview((ctypes.c_uint8 * self.kv_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), dtype=np.uint8
+ )
+ shape = (
+ self.kv_cache_tensor_meta.page_num,
+ self.kv_cache_tensor_meta.layer_num,
+ self.kv_cache_tensor_meta.token_page_size,
+ self.kv_cache_tensor_meta.num_heads,
+ self.kv_cache_tensor_meta.get_merged_head_dim(),
+ )
+ self.cpu_kv_cache_tensor = (
+ torch.from_numpy(numpy_array).view(dtype=self.kv_cache_tensor_meta.data_type).view(shape)
+ )
+ assert shm_ptr == self.cpu_kv_cache_tensor.data_ptr()
+
+ return handle
diff --git a/lightllm/utils/kv_cache_utils.py b/lightllm/utils/kv_cache_utils.py
index 10764e24b0..de0e3f76ca 100644
--- a/lightllm/utils/kv_cache_utils.py
+++ b/lightllm/utils/kv_cache_utils.py
@@ -59,7 +59,7 @@ def compute_token_list_hash(tokens: List[int], cpu_cache_token_page_size: int) -
@lru_cache(maxsize=None)
def calcu_cpu_cache_meta() -> "CpuKVCacheMeta":
args = get_env_start_args()
- assert args.enable_cpu_cache
+ assert args.enable_cpu_cache or args.enable_multimodal_x2i
mem_manager_class = select_mem_manager_class()
if mem_manager_class is Deepseek2MemoryManager:
diff --git a/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py b/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py
new file mode 100644
index 0000000000..ca6c5043c9
--- /dev/null
+++ b/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py
@@ -0,0 +1,212 @@
+import pytest
+import torch
+
+from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu_all
+
+# =========================================================
+# GPU guard
+pytestmark = pytest.mark.skipif(
+ not torch.cuda.is_available(),
+ reason="CUDA not available"
+)
+# =========================================================
+# helper: build a GPU KV cache tensor filled with distinct values
+def make_kv(L, T, H, D, device="cuda"):
+ x = torch.arange(L * T * H * D, dtype=torch.float32, device=device)
+ return x.view(L, T, H, D)
+
+
+# =========================================================
+# Test 1: basic copy (each token maps to one page slot)
+def test_basic_copy():
+ L, T, H, D = 2, 8, 4, 16
+ B = 1 # token_block_size
+ P = 4 # all_page_num
+
+ gpu_kv = make_kv(L, T, H, D)
+ cpu_kv = torch.zeros((P, L, B, H, D), dtype=torch.float32, pin_memory=True)
+
+ token_indexes = torch.tensor([1, 3, 5, 7], device="cuda")
+ page_indexes = torch.tensor([0, 1, 2, 3], device="cuda")
+
+ offload_gpu_kv_to_cpu_all(
+ token_indexes,
+ gpu_kv,
+ None,
+ cpu_kv,
+ None,
+ page_indexes,
+ tp_index=0,
+ tp_world_size=1,
+ grid_num=1,
+ )
+
+ for i in range(len(token_indexes)):
+ token = token_indexes[i].item()
+ page = page_indexes[i].item()
+
+ expected = gpu_kv[:, token, :, :].cpu() # (L,H,D)
+ actual = cpu_kv[page, :, 0, :, :] # (L,H,D)
+
+ assert torch.allclose(expected, actual)
+
+
+# =========================================================
+# Test 2: out-of-order token indexes
+def test_random_tokens():
+ L, T, H, D = 2, 16, 4, 8
+ B = 1
+ P = 6
+
+ gpu_kv = make_kv(L, T, H, D)
+ cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True)
+
+ token_indexes = torch.tensor([10, 2, 7, 15, 0, 3], device="cuda")
+ page_indexes = torch.arange(6, device="cuda")
+
+ offload_gpu_kv_to_cpu_all(
+ token_indexes,
+ gpu_kv,
+ None,
+ cpu_kv,
+ None,
+ page_indexes,
+ tp_index=0,
+ tp_world_size=1,
+ grid_num=1,
+ )
+
+ for i in range(6):
+ t = token_indexes[i].item()
+ p = page_indexes[i].item()
+
+ assert torch.allclose(
+ cpu_kv[p, :, 0, :, :],
+ gpu_kv[:, t, :, :].cpu()
+ )
+
+
+# =========================================================
+# Test 3: with scale tensors
+def test_with_scale():
+ L, T, H, D = 2, 8, 4, 16
+ B = 1
+ P = 3
+
+ gpu_kv = make_kv(L, T, H, D)
+ gpu_scale = torch.ones((L, T, H, D // 8), device="cuda") * 2.0
+
+ cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True)
+ cpu_scale = torch.zeros((P, L, B, H, D // 8), pin_memory=True)
+
+ token_indexes = torch.tensor([1, 2, 3], device="cuda")
+ page_indexes = torch.tensor([0, 1, 2], device="cuda")
+
+ offload_gpu_kv_to_cpu_all(
+ token_indexes,
+ gpu_kv,
+ gpu_scale,
+ cpu_kv,
+ cpu_scale,
+ page_indexes,
+ tp_index=0,
+ tp_world_size=1,
+ grid_num=1,
+ )
+
+ for i in range(3):
+ t = token_indexes[i].item()
+ p = page_indexes[i].item()
+
+ # KV
+ assert torch.allclose(
+ cpu_kv[p, :, 0, :, :],
+ gpu_kv[:, t, :, :].cpu()
+ )
+
+ # scale
+ assert torch.allclose(
+ cpu_scale[p, :, 0, :],
+ gpu_scale[:, t, :].cpu()
+ )
+
+
+# =========================================================
+# Test 4: tensor parallel (split along the head dimension)
+def test_tp_split():
+ L, T, H, D = 2, 8, 4, 16
+ B = 1
+ P = 2
+ tp_world_size = 2
+
+ gpu_kv = make_kv(L, T, H, D)
+
+ cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True)
+
+ token_indexes = torch.tensor([1, 2], device="cuda")
+ page_indexes = torch.tensor([0, 1], device="cuda")
+
+ offload_gpu_kv_to_cpu_all(
+ token_indexes, gpu_kv, None,
+ cpu_kv, None,
+ page_indexes,
+ tp_index=0,
+ tp_world_size=tp_world_size,
+ grid_num=1,
+ )
+
+ offload_gpu_kv_to_cpu_all(
+ token_indexes, gpu_kv, None,
+ cpu_kv, None,
+ page_indexes,
+ tp_index=1,
+ tp_world_size=tp_world_size,
+ grid_num=1,
+ )
+
+ split = H // tp_world_size
+
+ for i in range(2):
+ t = token_indexes[i].item()
+ p = page_indexes[i].item()
+
+ assert torch.allclose(
+ cpu_kv[p, :, 0, :split, :],
+ gpu_kv[:, t, :split, :].cpu()
+ )
+
+ assert torch.allclose(
+ cpu_kv[p, :, 0, split:, :],
+ gpu_kv[:, t, split:, :].cpu()
+ )
+
+
+# =========================================================
+# Test 5: empty input
+def test_empty():
+ L, T, H, D = 2, 8, 4, 16
+ B = 1
+ P = 2
+
+ gpu_kv = make_kv(L, T, H, D)
+ cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True)
+
+ token_indexes = torch.tensor([], dtype=torch.long, device="cuda")
+ page_indexes = torch.tensor([], dtype=torch.long, device="cuda")
+
+ offload_gpu_kv_to_cpu_all(
+ token_indexes,
+ gpu_kv,
+ None,
+ cpu_kv,
+ None,
+ page_indexes,
+ tp_index=0,
+ tp_world_size=1,
+ grid_num=1,
+ )
+
+ assert torch.all(cpu_kv == 0)
+
+if __name__ == "__main__":
+ pytest.main()
\ No newline at end of file
From 142e3df42a2a595066727c8c6fce2297f5c7d245 Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Wed, 25 Mar 2026 07:47:18 +0000
Subject: [PATCH 04/41] add naive x2i backend.
---
lightllm/models/neo_chat_moe/model.py | 4 +-
lightllm/server/api_cli.py | 6 +
lightllm/server/api_start.py | 7 +-
lightllm/server/core/objs/start_args_type.py | 3 +
.../core/objs/token_chunck_hash_list.py | 10 +
lightllm/server/core/objs/x2i_params.py | 58 +-
lightllm/server/httpserver/manager.py | 38 +-
lightllm/server/multimodal_params.py | 13 +
lightllm/server/x2i_server/manager.py | 61 +-
.../naive/configuration_neo_chat.py | 77 ++
.../x2i_server/naive/configuration_neo_vit.py | 52 +
.../x2i_server/naive/modeling_fm_modules.py | 591 +++++++++
.../x2i_server/naive/modeling_neo_chat.py | 755 +++++++++++
.../x2i_server/naive/modeling_neo_vit.py | 235 ++++
.../server/x2i_server/naive/modeling_qwen3.py | 1117 +++++++++++++++++
.../server/x2i_server/past_kv_cache_client.py | 14 +-
16 files changed, 2995 insertions(+), 46 deletions(-)
create mode 100644 lightllm/server/x2i_server/naive/configuration_neo_chat.py
create mode 100644 lightllm/server/x2i_server/naive/configuration_neo_vit.py
create mode 100644 lightllm/server/x2i_server/naive/modeling_fm_modules.py
create mode 100644 lightllm/server/x2i_server/naive/modeling_neo_chat.py
create mode 100644 lightllm/server/x2i_server/naive/modeling_neo_vit.py
create mode 100644 lightllm/server/x2i_server/naive/modeling_qwen3.py
diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py
index d9f40d7feb..edabc06075 100644
--- a/lightllm/models/neo_chat_moe/model.py
+++ b/lightllm/models/neo_chat_moe/model.py
@@ -47,7 +47,6 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
def load_conversion_module(self, model_dir: str):
import importlib
-
conversion_path = os.path.join(model_dir, "conversation.py")
if not os.path.exists(conversion_path):
return None
@@ -155,8 +154,7 @@ def get_query_for_it2i(self, prompt: str):
def get_query_for_t2i(self, prompt):
query_condition = self._build_t2i_query(
f"Please generate an image based on the following description: {prompt}",
- thinking_content="\n\n\n\n",
- )
+ thinking_content="\n\n\n\n")
query_uncondition = self._build_t2i_query(f"")
return query_condition, query_uncondition
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 06d8c57320..4c4be78b3c 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -307,6 +307,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
action="store_true",
help="Whether or not to allow to generate images (requird --enable_multimodal)."
)
+ parser.add_argument(
+ "--x2i_server_used_gpus",
+ type=int,
+ default=1,
+ help="Number of GPUs to use for x2i server (requird --enable_multimodal_x2i).",
+ )
parser.add_argument(
"--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service."
)
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index dc347e4929..358ae425b4 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -353,8 +353,9 @@ def normal_or_p_d_start(args):
)
if args.enable_multimodal_x2i:
- from .x2i_server.manager import start_x2i_process
-
+ from .x2i_server.manager import start_x2i_process, setup_devices
+ origin_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+ setup_devices(args)
process_manager.start_submodule_processes(
start_funcs=[
start_x2i_process,
@@ -363,6 +364,8 @@ def normal_or_p_d_start(args):
(args,),
],
)
+ if origin_devices:
+ os.environ["CUDA_VISIBLE_DEVICES"] = origin_devices
if args.enable_cpu_cache:
from .multi_level_kv_cache.manager import start_multi_level_kv_cache_manager
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index b741847b4c..5cf13ef19a 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -162,4 +162,7 @@ class StartArgs:
multi_level_kv_cache_port: int = field(default=None)
x2i_port: int = field(default=None)
http_server_port_for_x2i: int = field(default=None)
+ x2i_server_used_gpus: int = field(default=1)
+
+ # multi_modal
enable_multimodal_x2i: bool = field(default=False)
diff --git a/lightllm/server/core/objs/token_chunck_hash_list.py b/lightllm/server/core/objs/token_chunck_hash_list.py
index 479c6c17ff..b16f264936 100644
--- a/lightllm/server/core/objs/token_chunck_hash_list.py
+++ b/lightllm/server/core/objs/token_chunck_hash_list.py
@@ -94,8 +94,18 @@ class PastKVCachePageList(CpuCachePageList):
_pack_ = 4
_fields_ = CpuCachePageList._fields_ +[
("token_len", ctypes.c_int), # 对应的token数量
+ ("img_tokens", ctypes.c_int),
+ ("img_len", ctypes.c_int)
]
def __init__(self, token_len: int = 0):
super().__init__()
self.token_len = token_len
+ self.img_tokens = 0
+ self.img_len = 0
+
+ def get_compressed_len(self):
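+        # compressed length: total tokens minus the per-token image entries, plus one entry per image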
+ return self.token_len - self.img_tokens + self.img_len
+
+ def __repr__(self):
+ return f"(token_len={self.token_len}, img_tokens={self.img_tokens}, img_len={self.img_len})"
\ No newline at end of file
diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py
index 66fcbe9afb..50b2595ec2 100644
--- a/lightllm/server/core/objs/x2i_params.py
+++ b/lightllm/server/core/objs/x2i_params.py
@@ -1,6 +1,6 @@
import ctypes
from dataclasses import dataclass
-from typing import Dict, List
+from typing import Dict, List, Optional
from enum import IntEnum
from .token_chunck_hash_list import PastKVCachePageList
@@ -8,6 +8,21 @@ class CfgNormType(IntEnum):
NONE = 0
CFG_ZERO_STAR = 1
GLOBAL = 2
+ TEXT_CHANNEL = 3
+ CHANNEL = 4
+
+ def as_str(self) -> str:
+ mapping = {
+ CfgNormType.NONE: "none",
+ CfgNormType.CFG_ZERO_STAR: "cfg_zero_star",
+ CfgNormType.GLOBAL: "global",
+ CfgNormType.TEXT_CHANNEL: "text_channel",
+ CfgNormType.CHANNEL: "channel",
+ }
+ return mapping[self]
+
+ def __repr__(self):
+ return self.as_str()
class X2IParams(ctypes.Structure):
@@ -28,9 +43,9 @@ class X2IParams(ctypes.Structure):
("request_id", ctypes.c_int64),
]
- _width: int = 512
- _height: int = 512
- _steps: int = 30
+ _width: int = 1024
+ _height: int = 1024
+ _steps: int = 50
_guidance_scale: float = 7.0
_image_guidance_scale: float = 7.0
_seed: int = 42
@@ -56,8 +71,11 @@ def _get(key, default):
self.request_id = 0
def update(self, past_kv: PastKVCachePageList, meta: Dict):
- past_kv.token_len = meta.get("prompt_tokens")
- past_kv.fill(meta.get("kv_cache_pages"))
+ item: PastKVCacheItem = meta.get("kv_cache_item")
+ past_kv.token_len = item.token_len
+ past_kv.img_tokens = item.img_tokens
+ past_kv.img_len = item.img_len
+ past_kv.fill(item.page_indexes)
self.total_prompt_tokens += past_kv.token_len
def update_t2i(self, meta, meta_uncond):
@@ -69,12 +87,36 @@ def update_it2i(self, meta, meta_text_uncond, meta_img_uncond):
self.update(self.past_kvcache_text, meta_text_uncond)
self.update(self.past_kvcache_img, meta_img_uncond)
+ def get_cfg_norm(self):
+ return CfgNormType(self.cfg_norm).as_str()
+
+ def to_string(self):
+ parts = []
+ for field_name, _ in self._fields_:
+ value = getattr(self, field_name)
+ parts.append(f"{field_name}={value}")
+
+ return "X2IParams(" + ", ".join(parts) + ")"
+
+ def __repr__(self):
+ return self.to_string()
+
@dataclass
class X2IResponse:
request_id: int
- images: List[bytes]
+ images: Optional[List[bytes]]
+
@dataclass
class X2ICacheRelease:
- request_id: int
\ No newline at end of file
+ request_id: int
+
+
+@dataclass
+class PastKVCacheItem:
+ req_id: int
+ token_len: int
+ img_tokens: int
+ img_len: int
+ page_indexes: List[int]
\ No newline at end of file
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 1c943e4937..9502d630f0 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -25,7 +25,7 @@
from .async_queue import AsyncQueue
from lightllm.server.core.objs import Req, FinishStatus, StartArgs
from lightllm.server.core.objs import SamplingParams
-from lightllm.server.core.objs.x2i_params import X2IParams, X2ICacheRelease, X2IResponse
+from lightllm.server.core.objs.x2i_params import X2IParams, X2ICacheRelease, X2IResponse, PastKVCacheItem
from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
from lightllm.server.core.objs.io_objs import GroupReqObjs
from lightllm.server.core.objs.shm_req_manager import ShmReqManager
@@ -369,8 +369,10 @@ async def generate(
# allocate pages, may block if cache is full, but it won't cause deadlock
# because the prefill process is designed to be sequential and the pages
# will be released after prefill.
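+        # the total image token count and image count are passed to allocate_pages together with the prompt length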
+ img_tokens = sum([img.token_num for img in multimodal_params.images])
+ img_len = len(multimodal_params.images)
kv_pages = self.past_kv_cache_client.allocate_pages(
- req_obj.request_id, req_obj.input_len)
+ req_obj.request_id, req_obj.input_len, img_tokens, img_len)
req_obj.past_kv_cache_page_indexes.fill(kv_pages)
@@ -417,7 +419,7 @@ async def generate(
yield sub_req_id, request_output, metadata, finish_status
except Exception as e:
- logger.error(f"group_request_id: {group_request_id} has exception {str(e)}")
+ logger.error(f"group_request_id: {group_request_id} has exception {str(e)}", exc_info=e)
# error need to release multimodel resources.
             # multimodal resources that are not yet tracked by a formal request object must be released here manually
             # request objects already added to req_id_to_out_inf are reclaimed by the unified cleanup loop
@@ -441,35 +443,36 @@ async def generation_wrapper(prompt, sample, multimodal, request):
async for sub_req_id, _, metadata, finish_status in self.generate(
prompt, sample, multimodal, request
):
- kv_cache_pages = self.past_kv_cache_client.get_pages_by_req_id(sub_req_id)
- if kv_cache_pages is None:
+ kv_cache_item: PastKVCacheItem = self.past_kv_cache_client.get_pages_by_req_id(sub_req_id)
+ if kv_cache_item is None:
raise Exception(f"kv_cache_pages is None for sub_req_id {sub_req_id}")
- metadata["kv_cache_pages"] = kv_cache_pages
+ metadata["kv_cache_item"] = kv_cache_item
metadata["request_id"] = sub_req_id
metadata["finish_status"] = finish_status
generate_req_ids.append(sub_req_id)
return metadata
try:
- # 1. construct 3 or 2 images based on the multimodel_parmas
+            # 1. construct 2 or 3 generation requests based on the multimodal_params
sample_params = SamplingParams()
sample_params.init(self.tokenizer, **{"img_gen_prefill": True})
img_len = len(multimodal_params.images)
if img_len > 0:
# call it2i
- prompt_condition = f"Please generate an image based on the following instruction: {prompt}"
- prompt_text_uncondition = "Please generate an image based on the following instruction: "+ '\n' * img_len
- prompt_img_uncondition = "Please generate an image based on the following instruction: " + re.sub(r"\n?", "", prompt)
+            # fix the prompt: add image tags when img_len exceeds the number of image tags already in the prompt
+ prompt = self.tokenizer.fix_prompt(prompt, img_len)
+
+ prompt_condition, prompt_text_uncondition, prompt_img_uncondition = self.tokenizer.get_query_for_it2i(prompt)
(con_gen, text_uncon_gen, img_uncon_gen) = await asyncio.gather(*[
generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
- generation_wrapper(prompt_text_uncondition, sample_params, multimodal_params, request),
+ generation_wrapper(prompt_text_uncondition, sample_params, multimodal_params.clone(), request),
generation_wrapper(prompt_img_uncondition, sample_params, MultimodalParams(), request)])
generation_params.update_it2i(con_gen, text_uncon_gen, img_uncon_gen)
else:
# call t2i
- prompt_condition = f"Please generate an image based on the following caption: {prompt}"
- prompt_uncondition = f"Please generate an image based on the following caption: "
+ prompt_condition, prompt_uncondition = self.tokenizer.get_query_for_t2i(prompt)
+ logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}")
(con_gen, uncon_gen) = await asyncio.gather(*[
generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
generation_wrapper(prompt_uncondition, sample_params, multimodal_params, request)])
@@ -494,7 +497,7 @@ async def generation_wrapper(prompt, sample, multimodal, request):
except Exception as e:
logger.error(str(e))
- pass
+ return []
finally:
for req_id in generate_req_ids:
@@ -844,13 +847,18 @@ async def loop_for_x2i(self):
elif isinstance(recv_obj, X2IResponse):
status = self.req_id_to_x2i_reqs[recv_obj.request_id]
+ if recv_obj.images is None:
+ for req_id in status.req_ids:
+ self.past_kv_cache_client.free_pages_by_req_id(req_id)
+
status.response = recv_obj
status.event.set()
except asyncio.TimeoutError:
pass
+
except Exception as e:
- logger.error(e)
+ logger.error(e, exc_info=e)
async def handle_loop(self):
diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py
index 09a07455b3..999dfc6859 100644
--- a/lightllm/server/multimodal_params.py
+++ b/lightllm/server/multimodal_params.py
@@ -184,3 +184,16 @@ def to_origin_dict(self):
ret["images"] = [i.to_origin_dict() for i in self.images]
ret["audios"] = [a.to_origin_dict() for a in self.audios]
return ret
+
+ def free(self):
+ for image in self.images:
+ image.free()
+ for audio in self.audios:
+ audio.free()
+ return
+
+ def clone(self):
+ return MultimodalParams(**self.to_origin_dict())
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index 3dc82be897..11e53b5d67 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -6,6 +6,7 @@
import setproctitle
import pickle
import torch
+import os
from typing import List
from lightllm.server.core.objs import StartArgs
@@ -15,10 +16,12 @@
from lightllm.utils.process_check import start_parent_check_thread
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease
+from lightllm.utils.dist_utils import set_current_device_id
from .past_kv_cache_client import PastKVCacheClient
logger = init_logger(__name__)
+
'''
manage a generation service,
1. start x2v pipelines
@@ -44,10 +47,6 @@ def __init__(
self.waiting_reqs: List[X2IParams] = []
- from lightllm.utils.dist_utils import set_current_device_id
-
- set_current_device_id(torch.cuda.current_device())
-
self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True)
async def wait_to_model_ready(self):
@@ -61,8 +60,21 @@ async def wait_to_model_ready(self):
# config_json = self.args.x2v_gen_model_config,
# )
+ from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I
+
+ self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device())
+
pass
+ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams):
+ images = self.naive_x2i.t2i(past_kv_cache, past_kv_cache_text, param)
+ return images
+
+ async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams):
+ images = self.naive_x2i.it2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img, param)
+ return images
+
+
async def loop_for_fwd(self):
while True:
try:
@@ -81,8 +93,6 @@ async def loop_for_fwd(self):
)
is_t2i = x2i_param.past_kvcache_img.is_empty()
- logger.info(f"past kv cache shape: {past_kv_cache.shape}, past_kv_cache_text shape: {past_kv_cache_text.shape}")
-
past_kv_cache_img = None
if not is_t2i: # t2i
past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i(
@@ -94,13 +104,24 @@ async def loop_for_fwd(self):
X2ICacheRelease(request_id=x2i_param.request_id),
protocol=pickle.HIGHEST_PROTOCOL)
- # call generate images
+ images = []
+ logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with x2i_param: {x2i_param}")
+ if is_t2i:
+ images = await self.t2i_generate(past_kv_cache, past_kv_cache_text, x2i_param)
+ else:
+ images = await self.it2i_generate(past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param)
+
self.send_to_httpserver.send_pyobj(X2IResponse(
request_id=x2i_param.request_id,
- images=[]),
+ images=images),
protocol=pickle.HIGHEST_PROTOCOL)
except Exception as e:
+ self.send_to_httpserver.send_pyobj(X2IResponse(
+ request_id=x2i_param.request_id,
+ images=None),
+ protocol=pickle.HIGHEST_PROTOCOL)
+
logger.error(e)
@@ -115,11 +136,35 @@ async def loop_for_netio_req(self):
await asyncio.sleep(0.01)
+ def clean_up(self):
+ pass
+
+def setup_devices(args: StartArgs):
+ devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
+ logger.info(f"current devices: {devices} {torch.cuda.device_count()}")
+ if not devices:
+ devices = list(range(torch.cuda.device_count()))
+ else:
+ devices = [int(x.strip()) for x in devices.split(",") if x.strip()]
+
+ llm_need_gpus = args.tp * args.dp
+ x2i_need_gpus = args.x2i_server_used_gpus
+ if len(devices) < llm_need_gpus + x2i_need_gpus:
+ raise ValueError(f"devices {devices} not enough, need {llm_need_gpus} and {x2i_need_gpus}")
+
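+    # the llm processes keep the first llm_need_gpus devices; the x2i server is given the next x2i_need_gpus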
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices[
+ llm_need_gpus:llm_need_gpus + x2i_need_gpus]))
+
+ logger.info(f"setup devices for x2i server: {os.environ['CUDA_VISIBLE_DEVICES']}, "
+ f"{torch.cuda.device_count()} {torch.cuda.current_device()}")
+
+
def start_x2i_process(args, pipe_writer):
     # register handlers for graceful shutdown
graceful_registry(inspect.currentframe().f_code.co_name)
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::x2i_server")
start_parent_check_thread()
+ set_current_device_id(torch.cuda.current_device())
try:
x2iserver = X2IManager(args=args,)
asyncio.run(x2iserver.wait_to_model_ready())
diff --git a/lightllm/server/x2i_server/naive/configuration_neo_chat.py b/lightllm/server/x2i_server/naive/configuration_neo_chat.py
new file mode 100644
index 0000000000..44171917a5
--- /dev/null
+++ b/lightllm/server/x2i_server/naive/configuration_neo_chat.py
@@ -0,0 +1,77 @@
+import copy
+
+from transformers import Qwen3Config
+from transformers.utils import logging
+from transformers.configuration_utils import PretrainedConfig
+
+from .configuration_neo_vit import NEOVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class NEOLLMConfig(Qwen3Config):
+ def __init__(self, rope_theta_hw=10000.0, max_position_embeddings_hw=10000, **kwargs):
+ super().__init__(**kwargs)
+ self.rope_theta_hw = rope_theta_hw
+ self.max_position_embeddings_hw = max_position_embeddings_hw
+
+
+class NEOChatConfig(PretrainedConfig):
+ model_type = 'neo_chat'
+ is_composition = True
+
+ def __init__(
+ self,
+ vision_config=None,
+ llm_config=None,
+ use_backbone_lora=0,
+ use_llm_lora=0,
+ downsample_ratio=0.5,
+ template=None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ if vision_config is None:
+ vision_config = {'architectures': ['NEOVisionModel']}
+ logger.info('vision_config is None. Initializing the NEOVisionConfig with default values.')
+
+ if llm_config is None:
+ llm_config = {'architectures': ['Qwen3ForCausalLM']}
+            logger.info('llm_config is None. Initializing the NEOLLMConfig with default values.')
+ assert 'architectures' in llm_config, "Should specify architecture in llm_config"
+
+ if isinstance(vision_config, dict):
+ self.vision_config = NEOVisionConfig(**vision_config)
+ else:
+ self.vision_config = vision_config
+
+ if isinstance(llm_config, dict):
+ self.llm_config = NEOLLMConfig(**llm_config)
+ else:
+ self.llm_config = llm_config
+
+ self.use_backbone_lora = use_backbone_lora
+ self.use_llm_lora = use_llm_lora
+ self.downsample_ratio = downsample_ratio
+ self.template = template
+ self.tie_word_embeddings = self.llm_config.tie_word_embeddings
+
+ def to_dict(self):
+ """
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
+
+ Returns:
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+ """
+ output = copy.deepcopy(self.__dict__)
+ output['vision_config'] = self.vision_config.to_dict()
+ output['llm_config'] = self.llm_config.to_dict()
+ output['model_type'] = self.__class__.model_type
+ output['use_backbone_lora'] = self.use_backbone_lora
+ output['use_llm_lora'] = self.use_llm_lora
+ output['downsample_ratio'] = self.downsample_ratio
+ output['template'] = self.template
+
+ return output
diff --git a/lightllm/server/x2i_server/naive/configuration_neo_vit.py b/lightllm/server/x2i_server/naive/configuration_neo_vit.py
new file mode 100644
index 0000000000..02837fea41
--- /dev/null
+++ b/lightllm/server/x2i_server/naive/configuration_neo_vit.py
@@ -0,0 +1,52 @@
+import os
+from typing import Union
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class NEOVisionConfig(PretrainedConfig):
+
+ model_type = 'neo_vision'
+
+ def __init__(
+ self,
+ num_channels=3,
+ patch_size=16,
+ hidden_size=1024,
+ llm_hidden_size=2048,
+ downsample_ratio=0.5,
+ rope_theta_vision=10000.0,
+ max_position_embeddings_vision=10000,
+ min_pixels=65536,
+ max_pixels=4194304,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.hidden_size = hidden_size
+        self.llm_hidden_size = llm_hidden_size
+        self.downsample_ratio = downsample_ratio
+ self.rope_theta_vision = rope_theta_vision
+ self.max_position_embeddings_vision = max_position_embeddings_vision
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.min_pixels = min_pixels
+ self.max_pixels = max_pixels
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ if 'vision_config' in config_dict:
+ config_dict = config_dict['vision_config']
+
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
\ No newline at end of file
diff --git a/lightllm/server/x2i_server/naive/modeling_fm_modules.py b/lightllm/server/x2i_server/naive/modeling_fm_modules.py
new file mode 100644
index 0000000000..56654b2580
--- /dev/null
+++ b/lightllm/server/x2i_server/naive/modeling_fm_modules.py
@@ -0,0 +1,591 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import math
+from functools import lru_cache
+
+from torch.utils.checkpoint import checkpoint
+def modulate(x, shift, scale=None):
+ if shift is None:
+ return x * (1 + scale)
+ return x * (1 + scale) + shift
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+ return output * self.weight
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+ device=t.device
+ )
+ args = t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
+ return t_emb
+
+class ResBlock(nn.Module):
+
+ def __init__(self, channels, mlp_ratio=1.0):
+ super().__init__()
+ self.channels = channels
+ self.intermediate_size = int(channels * mlp_ratio)
+
+ self.in_ln = nn.LayerNorm(self.channels, eps=1e-6)
+ self.mlp = nn.Sequential(
+ nn.Linear(self.channels, self.intermediate_size),
+ nn.SiLU(),
+ nn.Linear(self.intermediate_size, self.channels),
+ )
+
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True))
+
+ def forward(self, x, y):
+ shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
+ h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
+ h = self.mlp(h)
+ return x + gate_mlp * h
+
+# class FinalLayer(nn.Module):
+
+# def __init__(self, model_channels, out_channels):
+# super().__init__()
+# self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
+# self.linear = nn.Linear(model_channels, out_channels, bias=True)
+# self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True))
+
+# def forward(self, x, c):
+# shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
+# x = modulate(self.norm_final(x), shift, scale)
+# x = self.linear(x)
+# return x
+
+# class SimpleMLPAdaLN(nn.Module):
+
+# def __init__(self, input_dim, out_dim, dim=1536, layers=12, mlp_ratio=1.0):
+# super().__init__()
+# self.input_dim = input_dim
+# self.out_dim = out_dim
+# self.dim = dim
+# self.layers = layers
+# self.mlp_ratio = mlp_ratio
+
+# self.time_embed = TimestepEmbedder(dim)
+# self.input_proj = nn.Linear(input_dim, dim)
+
+# res_blocks = []
+# for _ in range(layers):
+# res_blocks.append(ResBlock(dim, mlp_ratio))
+# self.res_blocks = nn.ModuleList(res_blocks)
+
+# self.final_layer = FinalLayer(dim, out_dim)
+
+# self.grad_checkpointing = False
+
+# self.initialize_weights()
+
+# def initialize_weights(self):
+# def _basic_init(module):
+# if isinstance(module, nn.Linear):
+# torch.nn.init.xavier_uniform_(module.weight)
+# if module.bias is not None:
+# nn.init.constant_(module.bias, 0)
+
+# self.apply(_basic_init)
+
+# # Initialize timestep embedding MLP
+# nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
+# nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
+
+# # Zero-out adaLN modulation layers
+# for block in self.res_blocks:
+# nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+# nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+# # Zero-out output layers
+# nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
+# nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
+# nn.init.constant_(self.final_layer.linear.weight, 0)
+# nn.init.constant_(self.final_layer.linear.bias, 0)
+
+# def forward(self, x, t):
+# """
+# x.shape = (bsz, input_dim)
+# t.shape = (bsz,)
+# """
+
+# x = self.input_proj(x)
+# t = self.time_embed(t)
+
+# y = t
+
+# for block in self.res_blocks:
+# if self.grad_checkpointing and self.training:
+# x = checkpoint(block, x, y, use_reentrant=True)
+# else:
+# x = block(x, y)
+
+# return self.final_layer(x, y)
+
+class FlowMatchingHead(nn.Module):
+
+ def __init__(self, input_dim, out_dim, dim=1536, layers=12, mlp_ratio=1.0):
+ super(FlowMatchingHead, self).__init__()
+ self.net = SimpleMLPAdaLN(input_dim=input_dim, out_dim=out_dim, dim=dim, layers=layers, mlp_ratio=mlp_ratio)
+
+ @property
+ def dtype(self):
+ return self.net.input_proj.weight.dtype
+
+ @property
+ def device(self):
+ return self.net.input_proj.weight.device
+
+ def forward(self, x, t):
+ x = self.net(x, t)
+ return x
+
+
+def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+ # assert H * H == end
+ # flat_patch_pos = torch.linspace(-1, 1, end) # N = end
+ x_pos = torch.linspace(0, scale, width)
+ y_pos = torch.linspace(0, scale, height)
+ y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
+ y_pos = y_pos.reshape(-1)
+ x_pos = x_pos.reshape(-1)
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
+ y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height*width, -1)
+ return freqs_cis
+
+class NerfEmbedder(nn.Module):
+ def __init__(self, in_channels, hidden_size_input, max_freqs):
+ super().__init__()
+ self.max_freqs = max_freqs
+ self.hidden_size_input = hidden_size_input
+ self.embedder = nn.Sequential(
+ nn.Linear(in_channels+max_freqs**2, hidden_size_input, bias=True),
+ )
+
+ @lru_cache
+ def fetch_pos(self, patch_size, device, dtype):
+ pos = precompute_freqs_cis_2d(self.max_freqs ** 2 * 2, patch_size, patch_size).real
+ pos = pos[None, :, :].to(device=device, dtype=dtype)
+ return pos
+
+
+ def forward(self, inputs):
+ B, P2, C = inputs.shape
+ patch_size = int(P2 ** 0.5)
+ device = inputs.device
+ dtype = inputs.dtype
+ dct = self.fetch_pos(patch_size, device, dtype)
+ dct = dct.repeat(B, 1, 1)
+ inputs = torch.cat([inputs, dct], dim=-1)
+ inputs = self.embedder(inputs)
+ return inputs
+
+class SimpleMLPAdaLN(nn.Module):
+ """
+ The MLP for Diffusion Loss.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param z_channels: channels in the condition.
+ :param num_res_blocks: number of residual blocks per downsample.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ z_channels,
+ num_res_blocks,
+ patch_size,
+ grad_checkpointing=False
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.grad_checkpointing = grad_checkpointing
+ self.patch_size = patch_size
+
+ self.cond_embed = nn.Linear(z_channels, patch_size**2*model_channels)
+
+ self.input_proj = nn.Linear(in_channels, model_channels)
+
+ res_blocks = []
+ for i in range(num_res_blocks):
+ res_blocks.append(ResBlock(
+ model_channels,
+ ))
+
+ self.res_blocks = nn.ModuleList(res_blocks)
+ self.final_layer = FinalLayer(model_channels, out_channels)
+
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+ self.apply(_basic_init)
+
+ # Zero-out adaLN modulation layers
+ for block in self.res_blocks:
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
+
+ # Zero-out output layers
+ nn.init.constant_(self.final_layer.linear.weight, 0)
+ nn.init.constant_(self.final_layer.linear.bias, 0)
+
+ def forward(self, x, c):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C] Tensor of inputs.
+ :param c: conditioning from AR transformer.
+ :return: an [N x C] Tensor of outputs.
+ """
+ x = self.input_proj(x)
+ c = self.cond_embed(c)
+
+ y = c.reshape(-1, self.patch_size**2, self.model_channels)
+
+ for block in self.res_blocks:
+ x = block(x, y)
+
+ return self.final_layer(x)
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer adopted from DiT.
+ """
+ def __init__(self, model_channels, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(model_channels, out_channels, bias=True)
+
+ def forward(self, x):
+ x = self.norm_final(x)
+ x = self.linear(x)
+ return x
+
+#################################################################################
+# Sine/Cosine Positional Embedding Functions #
+#################################################################################
+# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32) / pe_interpolation
+ grid_w = np.arange(grid_size, dtype=np.float32) / pe_interpolation
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model_path, pe_key: str = "gen_pos_embed", new_len: int = 4096):
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ pos_embed_1d = state_dict[pe_key]
+ _, ori_len, embed_dim = pos_embed_1d.shape
+
+ ori_size = int(ori_len**0.5)
+ new_size = int(new_len**0.5)
+
+ if ori_size != new_size:
+ logger.info("Position interpolate from %dx%d to %dx%d" % (ori_size, ori_size, new_size, new_size))
+ pos_embed_2d = pos_embed_1d.reshape(-1, ori_size, ori_size, embed_dim).permute(0, 3, 1, 2)
+ pos_embed_2d = torch.nn.functional.interpolate(
+ pos_embed_2d, size=(new_size, new_size), mode="bicubic", align_corners=False
+ )
+ pos_embed_1d = pos_embed_2d.permute(0, 2, 3, 1).flatten(1, 2)
+ state_dict[pe_key] = pos_embed_1d
+
+ torch.save(state_dict, model_path)
+
+class PositionEmbedding(nn.Module):
+ def __init__(self, max_num_patch_per_side, hidden_size):
+ super().__init__()
+ self.max_num_patch_per_side = max_num_patch_per_side
+ self.hidden_size = hidden_size
+ self.pos_embed = nn.Parameter(
+ torch.zeros(max_num_patch_per_side ** 2, hidden_size),
+ requires_grad=False
+ )
+ self._init_weights()
+
+ def _init_weights(self):
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
+ pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float())
+
+ def forward(self, position_ids):
+ return self.pos_embed[position_ids]
+
+
+class ResidualConvBlock(nn.Module):
+ def __init__(self, channels: int):
+ super().__init__()
+ self.block = nn.Sequential(
+ nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+ nn.SiLU(),
+ nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+ )
+ nn.init.zeros_(self.block[2].weight)
+ nn.init.zeros_(self.block[2].bias)
+
+ def forward(self, x):
+ return x + self.block(x)
+
+
+class PostConvSmoother(nn.Module):
+ def __init__(self, in_channels=3, hidden_channels=64, num_blocks=3):
+ super().__init__()
+ self.in_proj = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1)
+ self.blocks = nn.Sequential(*[ResidualConvBlock(hidden_channels) for _ in range(num_blocks)])
+ self.out_proj = nn.Conv2d(hidden_channels, in_channels, kernel_size=1)
+
+ nn.init.zeros_(self.out_proj.weight)
+ nn.init.zeros_(self.out_proj.bias)
+
+ def forward(self, x):
+ h = self.in_proj(x)
+ h = self.blocks(h)
+ return x + self.out_proj(h)
+
+
+class ProgressiveConvDecoder(nn.Module):
+ def __init__(self, hidden_dim=4096, out_channels=3):
+ super().__init__()
+
+ # self.proj = nn.Linear(hidden_dim, 1024)
+ # self.act = nn.SiLU()
+
+ self.up_blocks = nn.ModuleList([
+ nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='nearest'),
+ nn.Conv2d(hidden_dim, 512, kernel_size=3, padding=1),
+ nn.GroupNorm(32, 512),
+ nn.SiLU()
+ ),
+ nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='nearest'),
+ nn.Conv2d(512, 256, kernel_size=3, padding=1),
+ nn.GroupNorm(32, 256),
+ nn.SiLU()
+ ),
+ nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='nearest'),
+ nn.Conv2d(256, 64, kernel_size=3, padding=1),
+ nn.GroupNorm(32, 64),
+ nn.SiLU()
+ ),
+ nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='nearest'),
+ nn.Conv2d(64, 32, kernel_size=3, padding=1),
+ nn.GroupNorm(16, 32),
+ nn.SiLU()
+ ),
+ nn.Sequential(
+ nn.Upsample(scale_factor=2, mode='nearest'),
+ nn.Conv2d(32, 16, kernel_size=3, padding=1),
+ nn.SiLU()
+ )
+ ])
+
+ self.out_conv = nn.Conv2d(16, out_channels, kernel_size=3, padding=1)
+
+ def forward(self, x_2d):
+ # B, C, H, W = x_2d.shape
+ # x = x_2d.permute(0, 2, 3, 1).contiguous() # (B, H, W, C)
+ # x = self.proj(x)
+ # x = self.act(x)
+ # x = x.permute(0, 3, 1, 2).contiguous() # (B, 512, H, W)
+ x = x_2d
+ for block in self.up_blocks:
+ x = block(x)
+
+ out = self.out_conv(x)
+ return out
+
+
+class PatchDecoder_postps(nn.Module):
+ def __init__(self):
+ super().__init__()
+ # layer 1: H/32 -> H/8 (4x upscale)
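+        # (PixelShuffle(r) trades channels for resolution, C -> C / r**2: 4096 -> 256 after ps1, 192 -> 3 after ps2)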
+
+ self.conv1 = nn.Conv2d(4096, 4096, kernel_size=3, padding=1)
+ self.ps1 = nn.PixelShuffle(4)
+ self.act1 = nn.GELU()
+
+ # layer 2: H/8 -> H (8x upscale)
+ self.conv2 = nn.Conv2d(256, 192, kernel_size=3, padding=1)
+ self.ps2 = nn.PixelShuffle(8)
+
+ def forward(self, x):
+ # x shape: [B, 4096, H/32, W/32]
+ x = self.ps1(self.act1(self.conv1(x))) # -> [B, 256, H/8, W/8]
+ x = self.ps2(self.conv2(x)) # -> [B, 3, H, W]
+ return x
+
+
+class PatchDecoder_preps(nn.Module):
+ def __init__(self):
+ super().__init__()
+ # layer 1: H/32 -> H/16 (2x upscale)
+ self.ps1 = nn.PixelShuffle(2)
+ self.conv1 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
+ self.act1 = nn.GELU()
+
+ # layer 2: H/16 -> H/8 (2x upscale)
+ self.ps2 = nn.PixelShuffle(2)
+ self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+ self.act2 = nn.GELU()
+
+ # layer 3: H/8 -> H (8x upscale)
+ self.ps3 = nn.PixelShuffle(8)
+ self.conv3 = nn.Conv2d(4, 3, kernel_size=3, padding=1)
+
+ def forward(self, x):
+ # x shape: [B, 4096, H/32, W/32]
+        x = self.act1(self.conv1(self.ps1(x)))  # -> [B, 1024, H/16, W/16]
+        x = self.act2(self.conv2(self.ps2(x)))  # -> [B, 256, H/8, W/8]
+        x = self.conv3(self.ps3(x))  # -> [B, 3, H, W]
+ return x
+
+class PatchDecoder_preps1(nn.Module):
+ def __init__(self):
+ super().__init__()
+ # layer 1: H/32 -> H/16 (2x upscale)
+ self.ps1 = nn.PixelShuffle(2)
+ self.conv1 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
+ self.act1 = nn.GELU()
+
+ # layer 2: H/16 -> H/8 (2x upscale)
+ self.ps2 = nn.PixelShuffle(2)
+ self.conv2 = nn.Conv2d(256, 192, kernel_size=3, padding=1)
+
+ # layer 3: H/8 -> H (8x upscale)
+ self.ps3 = nn.PixelShuffle(8)
+
+ def forward(self, x):
+ # x shape: [B, 4096, H/32, W/32]
+        x = self.act1(self.conv1(self.ps1(x)))  # -> [B, 1024, H/16, W/16]
+        x = self.ps3(self.conv2(self.ps2(x)))  # -> [B, 3, H, W]
+ return x
+
+class ConvDecoder(nn.Module):
+ def __init__(self, input_dim=4096, hidden_dim=1024):
+ super().__init__()
+ # layer 1: H/32 -> H/16 (2x upscale)
+ self.ps1 = nn.PixelShuffle(2)
+ self.conv1 = nn.Conv2d(input_dim // 4, hidden_dim, kernel_size=3, padding=1)
+ self.act1 = nn.GELU()
+
+ # layer 2: H/16 -> H/8 (2x upscale)
+ self.ps2 = nn.PixelShuffle(2)
+ self.conv2 = nn.Conv2d(hidden_dim // 4, 192, kernel_size=3, padding=1)
+
+ # layer 3: H/8 -> H (8x upscale)
+ self.ps3 = nn.PixelShuffle(8)
+
+ def forward(self, x):
+        x = self.act1(self.conv1(self.ps1(x)))  # H/32 -> H/16
+        x = self.ps3(self.conv2(self.ps2(x)))  # H/16 -> H/8 -> H
+ return x
diff --git a/lightllm/server/x2i_server/naive/modeling_neo_chat.py b/lightllm/server/x2i_server/naive/modeling_neo_chat.py
new file mode 100644
index 0000000000..890d761391
--- /dev/null
+++ b/lightllm/server/x2i_server/naive/modeling_neo_chat.py
@@ -0,0 +1,755 @@
+from typing import List, Optional, Tuple, Union
+import math
+import torch.utils.checkpoint
+from torch import nn
+import transformers
+import numpy as np
+import base64
+from PIL import Image
+from torch.nn import CrossEntropyLoss
+from transformers import GenerationConfig
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.cache_utils import Cache, DynamicCache
+from lightllm.server.core.objs.x2i_params import X2IParams
+import torchvision.io as io
+
+from .configuration_neo_chat import NEOChatConfig
+from .modeling_neo_vit import NEOVisionModel
+from .modeling_qwen3 import Qwen3ForCausalLM, create_block_causal_mask
+from .modeling_fm_modules import PositionEmbedding, TimestepEmbedder, FlowMatchingHead, RMSNorm, NerfEmbedder, SimpleMLPAdaLN, ConvDecoder
+
+
+logger = logging.get_logger(__name__)
+
+
+def version_cmp(v1, v2, op='eq'):
+ import operator
+
+ from packaging import version
+ op_func = getattr(operator, op)
+ return op_func(version.parse(v1), version.parse(v2))
+
+def prepare_flash_kv_cache(
+ past_key_values,
+ current_len: int,
+ batch_size: int,
+):
+ """
+ Convert prefix cache from [B, H, S, D] to flash-attn friendly [B, S, H, D],
+ and preallocate full KV buffer for [prefix + current].
+
+ This is done once before denoising loop.
+ """
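+    # e.g. a 512-token prefix with 8 KV heads of head_dim 128 and current_len=1024 preallocates
+    # k_cache / v_cache of shape [B, 1536, 8, 128]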
+ if past_key_values is None:
+ return
+
+ for layer in past_key_values.layers:
+ past_k = layer.keys
+ past_v = layer.values
+
+ if past_k is None or past_v is None:
+ layer.flash_prefix_len = 0
+ layer.flash_total_len = current_len
+ layer.flash_k_cache = None
+ layer.flash_v_cache = None
+ continue
+
+ # original cache layout assumed: [B, H, S, D]
+ past_k_flash = past_k.transpose(1, 2).contiguous() # [B, S, H, D]
+ past_v_flash = past_v.transpose(1, 2).contiguous() # [B, S, H, D]
+
+ prefix_len = past_k_flash.shape[1]
+ total_len = prefix_len + current_len
+
+ k_cache = torch.empty(
+ (batch_size, total_len, past_k_flash.shape[2], past_k_flash.shape[3]),
+ device=past_k_flash.device,
+ dtype=past_k_flash.dtype,
+ )
+ v_cache = torch.empty(
+ (batch_size, total_len, past_v_flash.shape[2], past_v_flash.shape[3]),
+ device=past_v_flash.device,
+ dtype=past_v_flash.dtype,
+ )
+
+ k_cache[:, :prefix_len].copy_(past_k_flash)
+ v_cache[:, :prefix_len].copy_(past_v_flash)
+
+ layer.flash_prefix_len = prefix_len
+ layer.flash_total_len = total_len
+ layer.flash_k_cache = k_cache
+ layer.flash_v_cache = v_cache
+
+def clear_flash_kv_cache(past_key_values):
+ if past_key_values is None:
+ return
+ for layer in past_key_values.layers:
+ if hasattr(layer, "flash_prefix_len"):
+ delattr(layer, "flash_prefix_len")
+ if hasattr(layer, "flash_total_len"):
+ delattr(layer, "flash_total_len")
+ if hasattr(layer, "flash_k_cache"):
+ delattr(layer, "flash_k_cache")
+ if hasattr(layer, "flash_v_cache"):
+ delattr(layer, "flash_v_cache")
+
+
+@torch.cuda.amp.autocast(dtype=torch.float32)
+def optimized_scale(positive_flat, negative_flat):
+
+    # Calculate the dot product
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
+
+ # Squared norm of uncondition
+ squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
+
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+ st_star = dot_product / squared_norm
+
+ return st_star
+
+
+def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
+ """
+ Compute patch coordinates (x, y)
+
+ Args:
+ grid_hw: (B, 2) tensor representing (H, W) per image
+ """
+ device = grid_hw.device
+ B = grid_hw.shape[0]
+
+ # Get the number of patches per image
+ H = grid_hw[:, 0]
+ W = grid_hw[:, 1]
+ N = H * W
+ N_total = N.sum()
+
+ # Create the batch index for each patch (B x patch count)
+ patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,)
+
+ # Generate intra-image patch index (row-major order)
+ patch_id_within_image = torch.arange(N_total, device=device)
+ patch_id_within_image = patch_id_within_image - torch.cumsum(
+ torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0
+ )[patch_to_sample]
+
+ # Get H/W for each patch according to its image
+ W_per_patch = W[patch_to_sample]
+ abs_x = patch_id_within_image % W_per_patch
+ abs_y = patch_id_within_image // W_per_patch
+
+ return abs_x, abs_y
+
+
+class NEOChatModel(PreTrainedModel):
+ config_class = NEOChatConfig
+ main_input_name = 'pixel_values'
+ base_model_prefix = 'language_model'
+ _supports_flash_attn_2 = True
+ supports_gradient_checkpointing = True
+ _no_split_modules = [
+ "NEOVisionModel",
+ "Qwen3DecoderLayer",
+ ]
+
+ # support transformers 4.51.+
+ _tp_plan = ''
+
+ def __init__(self, config: NEOChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
+ super().__init__(config)
+
+ assert version_cmp(transformers.__version__, '4.37.0', 'ge')
+ patch_size = config.vision_config.patch_size
+ self.patch_size = patch_size
+ self.template = config.template
+ self.downsample_ratio = config.downsample_ratio
+ config.llm_config._attn_implementation = 'eager'
+
+ if vision_model is not None:
+ self.vision_model = vision_model
+ else:
+ self.vision_model = NEOVisionModel(config.vision_config)
+ vision_model_mot_gen = NEOVisionModel(config.vision_config)
+ if language_model is not None:
+ self.language_model = language_model
+ else:
+ self.language_model = Qwen3ForCausalLM(config.llm_config)
+
+ merge_size = int(1 / self.downsample_ratio)
+ output_dim = 3*(patch_size*merge_size)**2
+ llm_hidden_size = self.config.llm_config.hidden_size
+ self.use_deep_fm_head = self.config.fm_head_layers > 2
+ self.use_pixel_head = self.config.use_pixel_head
+
+ if self.use_deep_fm_head:
+ fm_head = FlowMatchingHead(llm_hidden_size, output_dim, dim=self.config.fm_head_dim, layers=self.config.fm_head_layers, mlp_ratio=self.config.fm_head_mlp_ratio)
+ else:
+ fm_head = nn.Sequential(
+ nn.Linear(llm_hidden_size, 4096, bias=True),
+ nn.GELU(),
+ nn.Linear(4096, output_dim, bias=True),
+ )
+
+ timestep_embedder = TimestepEmbedder(llm_hidden_size)
+ self.fm_modules = nn.ModuleDict(
+ {
+ "vision_model_mot_gen": vision_model_mot_gen,
+ "timestep_embedder": timestep_embedder,
+ "fm_head": fm_head
+ }
+ )
+ if self.use_pixel_head:
+ self.fm_modules["fm_head"] = ConvDecoder(llm_hidden_size)
+
+
+ self.concat_time_token_num = config.concat_time_token_num
+ self.time_token_id = 151682
+ self.noise_scale = config.noise_scale
+ self.noise_scale_mode = config.noise_scale_mode
+ self.noise_scale_base_image_seq_len = config.noise_scale_base_image_seq_len
+
+ self.add_noise_scale_embedding = config.add_noise_scale_embedding
+ self.noise_scale_max_value = config.noise_scale_max_value
+ self.time_schedule = config.time_schedule
+ self.time_shift_type = config.time_shift_type
+ self.base_shift = config.base_shift
+ self.max_shift = config.max_shift
+ self.base_image_seq_len = config.base_image_seq_len
+ self.max_image_seq_len = config.max_image_seq_len
+
+ if self.add_noise_scale_embedding:
+ noise_scale_embedder = TimestepEmbedder(llm_hidden_size)
+ self.fm_modules['noise_scale_embedder'] = noise_scale_embedder
+
+
+
+ self.img_context_token_id = None
+ self.img_start_token_id = 151670
+ # self.conv_template = get_conv_template(self.template)
+ # self.system_message = self.conv_template.system_message
+
+
+ def extract_feature(self, pixel_values, gen_model=False, grid_hw=None):
+ if gen_model:
+ return self.fm_modules['vision_model_mot_gen'](pixel_values=pixel_values,
+ output_hidden_states=False,
+ return_dict=True,
+ grid_hw=grid_hw).last_hidden_state
+ else:
+ return self.vision_model(pixel_values=pixel_values,
+ output_hidden_states=False,
+ return_dict=True,
+ grid_hw=grid_hw).last_hidden_state
+
+ def patchify(self, images, patch_size, channel_first=False):
+ """
+ images: (N, 3, H, W)
+ x: (N, L, patch_size**2 *3)
+ """
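+        # e.g. patch_size=32 on a 256x256 image gives L = (256/32)**2 = 64 patches of dim 32*32*3 = 3072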
+ h, w = images.shape[2] // patch_size, images.shape[3] // patch_size
+ x = images.reshape(shape=(images.shape[0], 3, h, patch_size, w, patch_size))
+
+ if channel_first:
+ x = torch.einsum('nchpwq->nhwcpq', x)
+ else:
+ x = torch.einsum('nchpwq->nhwpqc', x)
+
+ x = x.reshape(shape=(images.shape[0], h * w, patch_size**2 * 3))
+ return x
+
+    def unpatchify(self, x, patch_size, h=None, w=None):
+ """
+ x: (N, L, patch_size**2 *3)
+ images: (N, 3, H, W)
+ """
+ if h is None or w is None:
+ h = w = int(x.shape[1]**.5)
+ else:
+ h = h // patch_size
+ w = w // patch_size
+ x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, 3))
+ x = torch.einsum('nhwpqc->nchpwq', x)
+ images = x.reshape(shape=(x.shape[0], 3, h * patch_size, w * patch_size))
+ return images
+
+ def _euler_step(self, v_pred, z, t, t_next):
+ z_next = z + (t_next - t) * v_pred
+ return z_next
+
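+    # Linearly interpolate the timestep shift mu from base_shift (at base_image_seq_len)
+    # to max_shift (at max_image_seq_len).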
+ def _calculate_dynamic_mu(self, image_seq_len: int) -> float:
+ denom = self.max_image_seq_len - self.base_image_seq_len
+ if denom == 0:
+ return float(self.base_shift)
+ m = (self.max_shift - self.base_shift) / denom
+ b = self.base_shift - m * self.base_image_seq_len
+ return float(image_seq_len) * m + b
+
+ def _apply_time_schedule(self, t: torch.Tensor, image_seq_len: int, timestep_shift: float) -> torch.Tensor:
+ self.time_schedule = "standard"
+ sigma = 1 - t
+ if timestep_shift != 1:
+ self.time_schedule = "standard"
+ if self.time_schedule == "standard":
+ shift = timestep_shift
+ sigma = shift * sigma / (1 + (shift - 1) * sigma)
+ elif self.time_schedule == "dynamic":
+ mu = self._calculate_dynamic_mu(image_seq_len)
+ mu_t = t.new_tensor(mu)
+ if self.time_shift_type == "exponential":
+ shift = torch.exp(mu_t)
+ sigma = shift * sigma / (1 + (shift - 1) * sigma)
+ elif self.time_shift_type == "linear":
+ sigma = mu_t / (mu_t + (1 / sigma - 1))
+ else:
+ raise ValueError(f"Unsupported time_shift_type: {self.time_shift_type}")
+ else:
+ raise ValueError(f"Unsupported time_schedule: {self.time_schedule}")
+ return 1 - sigma
+
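+    # Build (3, N) position indexes for generated image tokens: every image token shares t = text_len,
+    # while (h, w) enumerate the token grid in row-major order.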
+ def _build_t2i_image_indexes(self, token_h, token_w, text_len, device):
+ t_image = torch.full((token_h * token_w,), text_len, dtype=torch.long, device=device)
+ idx = torch.arange(token_h * token_w, device=device, dtype=torch.long)
+ h_image = idx // token_w
+ w_image = idx % token_w
+ return torch.stack([t_image, h_image, w_image], dim=0)
+
+
+
+ def _t2i_predict_v(self, input_embeds, indexes_image, attn_mask, past_key_values, t, z,
+ image_token_num, timestep_embeddings=None, image_size=None):
+ B, L = z.shape[0], z.shape[1]
+
+ outputs = self.language_model.model(
+ inputs_embeds=input_embeds,
+ image_gen_indicators=torch.ones((input_embeds.shape[0], input_embeds.shape[1]), dtype=torch.bool, device=input_embeds.device),
+ indexes=indexes_image,
+ attention_mask=attn_mask,
+ past_key_values=past_key_values,
+ update_cache=False,
+ use_cache=True,
+ )
+
+ if self.use_pixel_head:
+ merge_size = int(1 / self.downsample_ratio)
+ token_h = image_size[1] // (self.patch_size * merge_size)
+ token_w = image_size[0] // (self.patch_size * merge_size)
+
+ img_reshaped = outputs.last_hidden_state[:, -image_token_num:].view(B, token_h, token_w, -1)
+ img_2d = torch.einsum("b h w c -> b c h w", img_reshaped)
+ img_2d = img_2d.contiguous().view(B, -1, token_h, token_w)
+
+ smoothed_img_2d = self.fm_modules['fm_head'](img_2d)
+
+ smoothed_reshaped = smoothed_img_2d.view(B, 3, token_h, self.patch_size * merge_size, token_w, self.patch_size * merge_size)
+ smoothed_reshaped = torch.einsum("b c h p w q -> b h w p q c", smoothed_reshaped)
+ out_1d = smoothed_reshaped.contiguous().view(B, L, self.patch_size * merge_size * self.patch_size * merge_size * 3)
+ x_pred = out_1d
+ else:
+ if self.use_deep_fm_head:
+ x_pred = self.fm_modules["fm_head"](
+ outputs.last_hidden_state[:, -image_token_num:].view(B*L, -1), t.repeat(B*L)
+ ).view(B, L, -1)
+ else:
+ x_pred = self.fm_modules["fm_head"](
+ outputs.last_hidden_state[:, -image_token_num:].view(B, L, -1)
+ ).view(B, L, -1)
+
+
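+        # convert the head's prediction x_pred into a velocity: v = (x_pred - z) / (1 - t), with (1 - t) clamped to t_eps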
+ v_pred = (x_pred - z) / (1 - t).clamp_min(self.config.t_eps)
+ return v_pred
+
+
+ @torch.no_grad()
+ def it2i_generate(self,
+ past_key_values_condition,
+ past_key_values_text_uncondition,
+ past_key_values_img_uncondition,
+ text_lens,
+ cfg_scale=1,
+ img_cfg_scale=1,
+ cfg_norm='none',
+ enable_timestep_shift=True,
+ timestep_shift=1,
+ image_size=(256, 256),
+ num_steps=30,
+ cfg_interval=(0.1, 1.0),
+ batch_size=1,
+ t_eps=0.02,
+ ):
+
+ self.config.t_eps = t_eps
+ device, dtype = self.get_cache_device_dtype(past_key_values_condition)
+ S1, S2, S3 = text_lens
+
+ merge_size = int(1 / self.downsample_ratio)
+
+ token_h = image_size[1] // (self.patch_size * merge_size)
+ token_w = image_size[0] // (self.patch_size * merge_size)
+
+ indexes_image_condition = self._build_t2i_image_indexes(token_h, token_w, S1, device=device)
+ indexes_image_text_uncondition = self._build_t2i_image_indexes(token_h, token_w, S2, device=device)
+ indexes_image_img_uncondition = self._build_t2i_image_indexes(token_h, token_w, S3, device=device)
+
+        for layer_idx in range(len(past_key_values_condition.layers)):
+            for cache in (past_key_values_condition, past_key_values_text_uncondition, past_key_values_img_uncondition):
+                layer = cache.layers[layer_idx]
+                layer.keys = layer.keys.expand(batch_size, *layer.keys.shape[1:])
+                layer.values = layer.values.expand(batch_size, *layer.values.shape[1:])
+
+ prepare_flash_kv_cache(
+ past_key_values_condition,
+ current_len=token_h * token_w,
+ batch_size=batch_size,
+ )
+ prepare_flash_kv_cache(
+ past_key_values_text_uncondition,
+ current_len=token_h * token_w,
+ batch_size=batch_size,
+ )
+ prepare_flash_kv_cache(
+ past_key_values_img_uncondition,
+ current_len=token_h * token_w,
+ batch_size=batch_size,
+ )
+
+
+ # init noise image tokens
+ grid_h = image_size[1] // self.patch_size
+ grid_w = image_size[0] // self.patch_size
+ grid_hw = torch.tensor([[grid_h, grid_w]]*batch_size, device=device)
+
+        noise_scale = self.noise_scale
+        if self.noise_scale_mode in ("resolution", "dynamic", "dynamic_sqrt"):
+            base = float(self.noise_scale_base_image_seq_len)
+            scale = math.sqrt((grid_h * grid_w) / (merge_size ** 2) / base)
+            noise_scale = scale * float(self.noise_scale)
+            if self.noise_scale_mode == "dynamic_sqrt":
+                noise_scale = math.sqrt(noise_scale)
+            noise_scale = min(noise_scale, self.noise_scale_max_value)
+
+ image_prediction = noise_scale * torch.randn((batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype)
+
+ attention_mask_condition = {"full_attention": None}
+ attention_mask_text_uncondition = {"full_attention": None}
+ attention_mask_img_uncondition = {"full_attention": None}
+
+ timesteps = torch.linspace(0.0, 1.0, num_steps+1, device=device)
+ if enable_timestep_shift:
+ timesteps = self._apply_time_schedule(timesteps, token_h*token_w, timestep_shift)
+
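+        # Denoising loop: re-encode the current prediction with the generation ViT, add timestep (and optional
+        # noise-scale) embeddings, predict the velocity under each conditioning, combine via CFG, then take an Euler step.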
+ for step_i in range(num_steps):
+ t = timesteps[step_i]
+ t_next = timesteps[step_i + 1]
+
+ z = self.patchify(image_prediction, self.patch_size * merge_size)
+ image_input = self.patchify(image_prediction, self.patch_size, channel_first=True)
+ image_embeds = self.extract_feature(image_input.view(batch_size * grid_h*grid_w, -1), gen_model=True, grid_hw=grid_hw).view(batch_size, token_h*token_w, -1)
+ t_expanded = t.expand(batch_size*token_h*token_w)
+ timestep_embeddings = self.fm_modules['timestep_embedder'](t_expanded).view(batch_size, token_h*token_w, -1)
+ if self.add_noise_scale_embedding:
+ noise_scale_tensor = torch.full_like(t_expanded, noise_scale/self.noise_scale_max_value)
+ noise_embeddings = self.fm_modules['noise_scale_embedder'](noise_scale_tensor).view(batch_size, token_h*token_w, -1)
+ timestep_embeddings += noise_embeddings
+ image_embeds = image_embeds + timestep_embeddings
+
+            v_pred_condition = self._t2i_predict_v(
+                image_embeds, indexes_image_condition, attention_mask_condition, past_key_values_condition, t, z,
+                image_token_num=token_h * token_w, timestep_embeddings=timestep_embeddings, image_size=image_size)
+            if t > cfg_interval[0] and t < cfg_interval[1]:
+                if cfg_scale > 1:
+                    v_pred_text_uncondition = self._t2i_predict_v(
+                        image_embeds, indexes_image_text_uncondition, attention_mask_text_uncondition,
+                        past_key_values_text_uncondition, t, z, image_token_num=token_h * token_w,
+                        timestep_embeddings=timestep_embeddings, image_size=image_size)
+                else:
+                    v_pred_text_uncondition = 0
+                if img_cfg_scale > 1:
+                    v_pred_img_uncondition = self._t2i_predict_v(
+                        image_embeds, indexes_image_img_uncondition, attention_mask_img_uncondition,
+                        past_key_values_img_uncondition, t, z, image_token_num=token_h * token_w,
+                        timestep_embeddings=timestep_embeddings, image_size=image_size)
+                else:
+                    v_pred_img_uncondition = 0
+
+ if t > cfg_interval[0] and t < cfg_interval[1]:
+ v_pred_text = v_pred_text_uncondition + cfg_scale * (v_pred_condition - v_pred_text_uncondition)
+ if cfg_norm == 'text_channel':
+ norm_v_condition = torch.norm(v_pred_condition, dim=-1, keepdim=True)
+ norm_v_cfg = torch.norm(v_pred_text, dim=-1, keepdim=True)
+ scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
+ v_pred_text = v_pred_text * scale
+ v_pred = v_pred_img_uncondition + img_cfg_scale * (v_pred_text - v_pred_img_uncondition)
+ if cfg_norm == 'global':
+ norm_v_condition = torch.norm(v_pred_condition, dim=(1,2), keepdim=True)
+ norm_v_cfg = torch.norm(v_pred, dim=(1,2), keepdim=True)
+ scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
+ v_pred = v_pred * scale
+ elif cfg_norm == 'channel':
+ norm_v_condition = torch.norm(v_pred_condition, dim=-1, keepdim=True)
+ norm_v_cfg = torch.norm(v_pred, dim=-1, keepdim=True)
+ scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
+ v_pred = v_pred * scale
+
+ else:
+ v_pred = v_pred_condition
+
+ z = z + (t_next - t) * v_pred
+
+ image_prediction = self.unpatchify(z, self.patch_size * merge_size, image_size[1], image_size[0])
+
+ clear_flash_kv_cache(past_key_values_condition)
+ clear_flash_kv_cache(past_key_values_text_uncondition)
+ clear_flash_kv_cache(past_key_values_img_uncondition)
+
+ return image_prediction
+
+
+ def get_cache_device_dtype(self, cache):
+ """
+ Returns (device, dtype) of a DynamicCache.
+ Assumes all layers share same device/dtype.
+ """
+ for layer in cache.layers:
+ return layer.device, layer.dtype
+ raise ValueError("Cache is empty")
+
+ @torch.no_grad()
+ def t2i_generate(self,
+ past_key_values_condition,
+ past_key_values_uncondition,
+ text_lens,
+ cfg_scale=1,
+ timestep_shift=1,
+ enable_timestep_shift=True,
+ cfg_norm='none',
+ image_size=(256, 256),
+ num_steps=30,
+ cfg_interval=(0.1, 1.0),
+ batch_size=1,
+ t_eps=0.02):
+ assert cfg_norm in ['cfg_zero_star', 'global', 'none'], f"cfg_norm={cfg_norm}"
+ merge_size = int(1 / self.downsample_ratio)
+ self.config.t_eps = t_eps
+
+ token_h = image_size[1] // (self.patch_size * merge_size)
+ token_w = image_size[0] // (self.patch_size * merge_size)
+
+ device, dtype = self.get_cache_device_dtype(past_key_values_condition)
+ S1, S2 = text_lens
+
+ indexes_image_condition = self._build_t2i_image_indexes(token_h, token_w, S1, device=device)
+ indexes_image_uncondition = self._build_t2i_image_indexes(token_h, token_w, S2, device=device)
+
+        for layer_idx in range(len(past_key_values_condition.layers)):
+            for cache in (past_key_values_condition, past_key_values_uncondition):
+                layer = cache.layers[layer_idx]
+                layer.keys = layer.keys.expand(batch_size, *layer.keys.shape[1:])
+                layer.values = layer.values.expand(batch_size, *layer.values.shape[1:])
+
+ # prepare flash cache once
+ prepare_flash_kv_cache(
+ past_key_values_condition,
+ current_len=token_h * token_w,
+ batch_size=batch_size,
+ )
+ prepare_flash_kv_cache(
+ past_key_values_uncondition,
+ current_len=token_h * token_w,
+ batch_size=batch_size,
+ )
+
+ # init noise image tokens
+ grid_h = image_size[1] // self.patch_size
+ grid_w = image_size[0] // self.patch_size
+ grid_hw = torch.tensor([[grid_h, grid_w]]*batch_size, device=device)
+
+        noise_scale = self.noise_scale
+        if self.noise_scale_mode in ("resolution", "dynamic", "dynamic_sqrt"):
+            base = float(self.noise_scale_base_image_seq_len)
+            scale = math.sqrt((grid_h * grid_w) / (merge_size ** 2) / base)
+            noise_scale = scale * float(self.noise_scale)
+            if self.noise_scale_mode == "dynamic_sqrt":
+                noise_scale = math.sqrt(noise_scale)
+            noise_scale = min(noise_scale, self.noise_scale_max_value)
+
+ image_prediction = noise_scale * torch.randn((batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype)
+
+ attention_mask_condition = {"full_attention": None}
+ attention_mask_uncondition = {"full_attention": None}
+
+ timesteps = torch.linspace(0.0, 1.0, num_steps+1, device=device)
+
+ if enable_timestep_shift:
+ timesteps = self._apply_time_schedule(timesteps, token_h*token_w, timestep_shift)
+
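+        # Denoising loop: re-encode the current prediction, add timestep/noise-scale embeddings, predict the
+        # conditional (and optionally unconditional) velocity, apply CFG, then take an Euler step to t_next.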
+ for step_i in range(num_steps):
+ t = timesteps[step_i]
+ t_next = timesteps[step_i + 1]
+
+ z = self.patchify(image_prediction, self.patch_size * merge_size)
+ image_input = self.patchify(image_prediction, self.patch_size, channel_first=True)
+ image_embeds = self.extract_feature(image_input.view(batch_size * grid_h*grid_w, -1), gen_model=True, grid_hw=grid_hw).view(batch_size, token_h*token_w, -1)
+ t_expanded = t.expand(batch_size*token_h*token_w)
+ timestep_embeddings = self.fm_modules['timestep_embedder'](t_expanded).view(batch_size, token_h*token_w, -1)
+ if self.add_noise_scale_embedding:
+ noise_scale_tensor = torch.full_like(t_expanded, noise_scale / self.noise_scale_max_value)
+ noise_embeddings = self.fm_modules['noise_scale_embedder'](noise_scale_tensor).view(batch_size, token_h*token_w, -1)
+ timestep_embeddings += noise_embeddings
+ image_embeds = image_embeds + timestep_embeddings
+
+
+ v_pred_condition = self._t2i_predict_v(image_embeds, indexes_image_condition, attention_mask_condition, past_key_values_condition, t, z, image_token_num=token_h*token_w,
+ timestep_embeddings=timestep_embeddings, image_size=image_size)
+
+
+ if t > cfg_interval[0] and t < cfg_interval[1] and cfg_scale > 1:
+ v_pred_uncondition = self._t2i_predict_v(image_embeds, indexes_image_uncondition, attention_mask_uncondition, past_key_values_uncondition, t, z, image_token_num=token_h*token_w,
+ timestep_embeddings=timestep_embeddings, image_size=image_size)
+ if cfg_norm == 'cfg_zero_star':
+ positive_flat = v_pred_condition.view(batch_size, -1)
+ negative_flat = v_pred_uncondition.view(batch_size, -1)
+
+ alpha = optimized_scale(positive_flat,negative_flat)
+ alpha = alpha.view(batch_size, *([1] * (len(v_pred_condition.shape) - 1)))
+ alpha = alpha.to(positive_flat.dtype)
+
+ if (step_i <= 0):
+ v_pred = v_pred_condition*0.
+ else:
+ v_pred = v_pred_uncondition * alpha + cfg_scale * (v_pred_condition - v_pred_uncondition * alpha)
+ else:
+ v_pred = v_pred_uncondition + cfg_scale * (v_pred_condition - v_pred_uncondition)
+ if cfg_norm == 'global':
+ norm_v_condition = torch.norm(v_pred_condition, dim=(1,2), keepdim=True)
+ norm_v_cfg = torch.norm(v_pred, dim=(1,2), keepdim=True)
+ scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
+ v_pred = v_pred * scale
+ else:
+ v_pred = v_pred_condition
+
+ z = z + (t_next - t) * v_pred
+
+ image_prediction = self.unpatchify(z, self.patch_size * merge_size, image_size[1], image_size[0])
+
+ clear_flash_kv_cache(past_key_values_condition)
+ clear_flash_kv_cache(past_key_values_uncondition)
+
+ return image_prediction
+
+
+ @property
+ def lm_head(self):
+ return self.language_model.get_output_embeddings()
+
+ def get_output_embeddings(self):
+ return self.language_model.get_output_embeddings()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ return self.language_model.set_input_embeddings(value)
+
+ def set_output_embeddings(self, value):
+ return self.language_model.set_output_embeddings(value)
+
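+    # Build (t, h, w) indexes for a mixed text/image sequence: t advances token-by-token over text and once per
+    # image (all of an image's context tokens share one t), while (h, w) hold each image token's patch
+    # coordinates (zero for text tokens).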
+ def get_thw_indexes(self, input_ids, grid_hw=None):
+ img_start_shift = torch.cat([torch.zeros(1, dtype=torch.long).to(input_ids.device),
+ (input_ids == self.img_start_token_id).long()], dim=0)[:-1]
+ not_img_token = (input_ids != self.img_context_token_id).long()
+ t_indexes = ((img_start_shift + not_img_token).cumsum(0) - 1)
+ h_indexes = torch.zeros_like(t_indexes).to(t_indexes.device)
+ w_indexes = torch.zeros_like(t_indexes).to(t_indexes.device)
+
+ if grid_hw is not None:
+ selected = (input_ids == self.img_context_token_id)
+ if selected.long().sum() > 0:
+ abs_pos_w, abs_pos_h = build_abs_positions_from_grid_hw(
+ grid_hw // int(1 / self.downsample_ratio), device=t_indexes.device)
+ h_indexes[selected] = abs_pos_h.to(t_indexes.device, t_indexes.dtype)
+ w_indexes[selected] = abs_pos_w.to(t_indexes.device, t_indexes.dtype)
+ return torch.stack([t_indexes, h_indexes, w_indexes], dim=0)
+
+
+NORM_MEAN = [0.5, 0.5, 0.5]
+NORM_STD = [0.5, 0.5, 0.5]
+
+class NEOX2I:
+ def __init__(self, model_path, device):
+ self.device = device
+ self.model: NEOChatModel = NEOChatModel.from_pretrained(
+ model_path,
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=True,
+ ).to(self.device)
+
+ self.model.eval()
+
+ def _denorm(self, x: torch.Tensor, mean=NORM_MEAN, std=NORM_STD):
+ """
+ x: [B,3,H,W] normalized ((img-mean)/std). returns [0,1] clamped.
+ """
+ mean = torch.tensor(mean, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
+ std = torch.tensor(std, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
+ return (x * std + mean).clamp(0, 1)
+
+ def _get_dynamic_cache(self, past_kv):
+ """
+ past_kv ( L, 2, H // 2, P * S, D)
+ """
+ past_kv_dc = DynamicCache(config=self.model.language_model.model.config)
+ L, _, H, S, D = past_kv.shape
+ for layer_idx in range(L):
+ k = past_kv[layer_idx][0].unsqueeze(0).to(self.device, non_blocking=True)
+ v = past_kv[layer_idx][1].unsqueeze(0).to(self.device, non_blocking=True)
+ past_kv_dc.update(key_states=k, value_states=v, layer_idx=layer_idx,)
+ return past_kv_dc
+
+
+ def t2i(self, past_kv, past_kv_txt, param: X2IParams):
+ past_kv_dc = self._get_dynamic_cache(past_kv)
+ past_kv_txt_dc = self._get_dynamic_cache(past_kv_txt)
+ text_lens = (param.past_kvcache.get_compressed_len(),
+ param.past_kvcache_text.get_compressed_len())
+ output = self.model.t2i_generate(
+ past_key_values_condition=past_kv_dc,
+ past_key_values_uncondition=past_kv_txt_dc,
+ text_lens=text_lens,
+ cfg_norm=param.get_cfg_norm(),
+ cfg_scale=param.guidance_scale,
+ image_size=(param.width, param.height),
+ num_steps=param.steps,
+ batch_size=param.num_images)
+
+ return self._post_process(output)
+
+ def _post_process(self, output):
+ images = self._denorm(output)
+ images = (images.clamp(0, 1) * 255.0).round().to(torch.uint8).cpu()
+
+ base64_images = [
+ base64.b64encode(io.encode_jpeg(img).numpy()).decode("utf-8")
+ for img in images
+ ]
+ return base64_images
+
+ def it2i(self, past_kv, past_kv_txt, past_kv_img, param: X2IParams):
+ past_kv_dc = self._get_dynamic_cache(past_kv)
+ past_kv_txt_dc = self._get_dynamic_cache(past_kv_txt)
+ past_kv_img_dc = self._get_dynamic_cache(past_kv_img)
+ text_lens = (param.past_kvcache.get_compressed_len(),
+ param.past_kvcache_text.get_compressed_len(),
+ param.past_kvcache_img.get_compressed_len())
+ output = self.model.it2i_generate(
+ past_key_values_condition=past_kv_dc,
+ past_key_values_text_uncondition=past_kv_txt_dc,
+ past_key_values_img_uncondition=past_kv_img_dc,
+ text_lens=text_lens,
+ cfg_norm=param.get_cfg_norm(),
+ cfg_scale=param.guidance_scale,
+ img_cfg_scale=param.image_guidance_scale,
+ image_size=(param.width, param.height),
+ num_steps=param.steps,
+ batch_size=param.num_images,
+ )
+
+ return self._post_process(output)
diff --git a/lightllm/server/x2i_server/naive/modeling_neo_vit.py b/lightllm/server/x2i_server/naive/modeling_neo_vit.py
new file mode 100644
index 0000000000..63ded83e7d
--- /dev/null
+++ b/lightllm/server/x2i_server/naive/modeling_neo_vit.py
@@ -0,0 +1,235 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+
+from .configuration_neo_vit import NEOVisionConfig
+
+
+def precompute_rope_freqs_sincos(
+ dim: int, max_position: int, base: float = 10000.0, device=None
+):
+    """Precompute the cos and sin tables for 1D RoPE."""
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
+ t = torch.arange(max_position, device=device).type_as(inv_freq)
+ freqs = torch.outer(t, inv_freq)
+ return torch.cos(freqs), torch.sin(freqs)
+
+
+def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
+ """
+ Compute patch coordinates (x, y)
+
+ Args:
+ grid_hw: (B, 2) tensor representing (H, W) per image
+ """
+ device = grid_hw.device
+ B = grid_hw.shape[0]
+
+ # Get the number of patches per image
+ H = grid_hw[:, 0]
+ W = grid_hw[:, 1]
+ N = H * W
+ N_total = N.sum()
+
+ # Create the batch index for each patch (B x patch count)
+ patch_to_sample = torch.repeat_interleave(torch.arange(B, device=device), N) # (N_total,)
+
+ # Generate intra-image patch index (row-major order)
+ patch_id_within_image = torch.arange(N_total, device=device)
+ patch_id_within_image = patch_id_within_image - torch.cumsum(
+ torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0
+ )[patch_to_sample]
+
+ # Get H/W for each patch according to its image
+ W_per_patch = W[patch_to_sample]
+ abs_x = patch_id_within_image % W_per_patch
+ abs_y = patch_id_within_image // W_per_patch
+
+ return abs_x, abs_y
+
+
+def apply_rotary_emb_1d(
+ x: torch.Tensor,
+ cos_cached: torch.Tensor,
+ sin_cached: torch.Tensor,
+ positions: torch.Tensor,
+):
+    """Apply 1D RoPE to a slice of the input tensor."""
+ # x: (..., seq_len, dim_part)
+ # positions: (..., seq_len)
+ # cos_cached: (max_pos, dim_part / 2)
+
+ cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2)
+ sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2)
+
+ x1 = x[..., 0::2]
+ x2 = x[..., 1::2]
+
+ rotated_x1 = x1 * cos - x2 * sin
+ rotated_x2 = x1 * sin + x2 * cos
+
+ x_rotated = torch.empty_like(x)
+ x_rotated[..., 0::2] = rotated_x1
+ x_rotated[..., 1::2] = rotated_x2
+ return x_rotated
+
+
+def apply_2d_rotary_pos_emb(
+ x: torch.Tensor,
+ cos_cached_x: torch.Tensor,
+ sin_cached_x: torch.Tensor,
+ cos_cached_y: torch.Tensor,
+ sin_cached_y: torch.Tensor,
+ abs_positions_x: torch.Tensor,
+ abs_positions_y: torch.Tensor
+):
+    """Apply 2D RoPE to the input tensor x."""
+    dim = x.shape[-1]
+    dim_half = dim // 2
+
+    # Split the embedding in half: the first half is rotated with the X coordinates,
+    # the second half with the Y coordinates (the split only needs to stay consistent).
+    x_part_1 = x[..., :dim_half]
+    x_part_2 = x[..., dim_half:]
+
+    # Rotate x_part_1 with the positions along X
+    rotated_part_1 = apply_rotary_emb_1d(
+        x_part_1, cos_cached_x, sin_cached_x, abs_positions_x
+    )
+    # Rotate x_part_2 with the positions along Y
+    rotated_part_2 = apply_rotary_emb_1d(
+        x_part_2, cos_cached_y, sin_cached_y, abs_positions_y
+    )
+
+    # Concatenate back in the same order as the split.
+    return torch.cat((rotated_part_1, rotated_part_2), dim=-1)
+
+
+class NEOVisionEmbeddings(nn.Module):
+ """
+ Embedding Module for Vision.
+ """
+
+ def __init__(self, config: NEOVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.llm_embed_dim = config.llm_hidden_size[0]
+ self.downsample_factor = int(1 / config.downsample_ratio[0])
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
+ )
+ self.dense_embedding = nn.Conv2d(
+ in_channels=self.embed_dim, out_channels=self.llm_embed_dim, kernel_size=self.downsample_factor, stride=self.downsample_factor
+ )
+ self.gelu = nn.GELU()
+
+ self.rope_dim_part = self.embed_dim // 2
+ cos_x, sin_x = precompute_rope_freqs_sincos(
+ self.rope_dim_part, config.max_position_embeddings_vision, base=config.rope_theta_vision, device=None
+ )
+ cos_y, sin_y = precompute_rope_freqs_sincos(
+ self.rope_dim_part, config.max_position_embeddings_vision, base=config.rope_theta_vision, device=None
+ )
+
+ self.register_buffer("cos_cached_x", cos_x, persistent=False)
+ self.register_buffer("sin_cached_x", sin_x, persistent=False)
+ self.register_buffer("cos_cached_y", cos_y, persistent=False)
+ self.register_buffer("sin_cached_y", sin_y, persistent=False)
+
+ def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw):
+ """
+ Apply 2D Rotary Position Embedding to the patch embeddings.
+ """
+ abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device)
+ embeddings = apply_2d_rotary_pos_emb(
+ patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32
+ self.cos_cached_x, self.sin_cached_x,
+ self.cos_cached_y, self.sin_cached_y,
+ abs_pos_x,
+ abs_pos_y
+ ).to(self.patch_embedding.weight.dtype)
+ return embeddings
+
+ def forward(self, pixel_values: torch.FloatTensor, grid_hw=None) -> torch.Tensor:
+
+        # unflatten patch pixels: [N_total, 3 * patch_size**2] -> [N_total, 3, patch_size, patch_size]
+        pixel_values = pixel_values.view(-1, 3, self.patch_size, self.patch_size)
+ patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim)
+ self.cos_cached_x = self.cos_cached_x.to(patch_embeds.device)
+ self.sin_cached_x = self.sin_cached_x.to(patch_embeds.device)
+ self.cos_cached_y = self.cos_cached_y.to(patch_embeds.device)
+ self.sin_cached_y = self.sin_cached_y.to(patch_embeds.device)
+ patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) # [28072, 1024]
+ assert (grid_hw[:,0] * grid_hw[:,1]).sum() == patch_embeds.shape[0]
+
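+        # downsample each image independently: reshape its tokens to (H, W), apply the strided
+        # dense_embedding conv, then flatten back to a token sequence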
+ patches_list = []
+ cur_position = 0
+ for i in range(grid_hw.shape[0]):
+ h, w = grid_hw[i]
+ patches_per_img = patch_embeds[cur_position : cur_position + h * w].view(h, w, -1).unsqueeze(0)
+ patches_per_img = self.dense_embedding(patches_per_img.permute(0, 3, 1, 2))
+ patches_per_img = patches_per_img.permute(0, 2, 3, 1)
+ patches_list.append(patches_per_img.view(-1, patches_per_img.shape[-1]))
+ cur_position += h * w
+
+ embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C)
+
+ assert cur_position == patch_embeds.shape[0]
+ assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor**2)
+
+ return embeddings
+
+
+class NEOVisionModel(PreTrainedModel):
+ main_input_name = 'pixel_values'
+ _supports_flash_attn_2 = True
+ supports_gradient_checkpointing = True
+ config_class = NEOVisionConfig
+ # support transformers 4.51.+
+ _tp_plan = ''
+
+ def __init__(self, config: NEOVisionConfig):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = NEOVisionEmbeddings(config)
+
+ def forward(
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_embeds: Optional[torch.FloatTensor] = None,
+ grid_hw: Optional[torch.Tensor] = None
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if pixel_values is None and pixel_embeds is None:
+ raise ValueError('You have to specify pixel_values or pixel_embeds')
+
+ if pixel_embeds is not None:
+ hidden_states = pixel_embeds
+ else:
+ assert pixel_values.dim() == 2, f"pixel_values must be 2D for native resolution, got: {pixel_values.dim()}"
+ hidden_states = self.embeddings(pixel_values, grid_hw=grid_hw)
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=hidden_states,
+ pooler_output=None,
+ hidden_states=None,
+ attentions=None,
+ )
diff --git a/lightllm/server/x2i_server/naive/modeling_qwen3.py b/lightllm/server/x2i_server/naive/modeling_qwen3.py
new file mode 100644
index 0000000000..5a4c5acc40
--- /dev/null
+++ b/lightllm/server/x2i_server/naive/modeling_qwen3.py
@@ -0,0 +1,1117 @@
+from typing import Callable, Optional, Union
+
+import torch
+import torch._dynamo
+from torch import nn
+
+import copy
+import math
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.generation import GenerationMixin
+from transformers.integrations import use_kernel_forward_from_hub
+from transformers.masking_utils import create_causal_mask
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import (
+ GenericForQuestionAnswering,
+ GenericForSequenceClassification,
+ GenericForTokenClassification,
+ GradientCheckpointingLayer,
+)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+from transformers.utils.deprecation import deprecate_kwarg
+from transformers.utils.generic import check_model_inputs
+from transformers import Qwen3Config
+
+from flash_attn import flash_attn_func
+
+
+def create_block_causal_mask(index: torch.Tensor):
+ """
+ index: (L)
+ return: (1, 1, L, L) block-wise causal attention mask
+ """
+ L = index.size(0)
+ idx_i = index.unsqueeze(1).expand(L, L)
+ idx_j = index.unsqueeze(0).expand(L, L)
+
+ arange = torch.arange(L, device=index.device)
+ mask = (idx_j == idx_i) | (arange.unsqueeze(0) <= arange.unsqueeze(1))
+
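+    # rows attend causally to earlier positions and bidirectionally within their own index block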
+ return torch.where(mask[None, None, :, :] > 0, torch.tensor(0.0), torch.tensor(float('-inf')))
+
+
+def visualize_mask(mask: torch.Tensor, i: int = 0, j: int = 12):
+ """
+ mask: (1,1, L, L)
+ """
+    submask = torch.where(mask[0, 0, :, :] == 0, torch.tensor(1.0), torch.tensor(0.0))
+    submask = submask[i:j, i:j].int().cpu().numpy()
+ for row in submask:
+ print(" ".join(map(str, row)))
+
+
+@use_kernel_forward_from_hub("RMSNorm")
+class Qwen3RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps: float = 1e-6) -> None:
+ """
+ Qwen3RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Qwen3MLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: float,
+ dropout: float = 0.0,
+ **kwargs: Unpack[TransformersKwargs],
+):
+ key_states = repeat_kv(key, module.num_key_value_groups)
+ value_states = repeat_kv(value, module.num_key_value_groups)
+
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+class Qwen3RotaryEmbedding(nn.Module):
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
+
+ def __init__(self, config: Qwen3Config, device=None):
+ super().__init__()
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ base_rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+ def _rope_init_fn_keep_freq_range(cfg: Qwen3Config, dev=None):
+ inv_freq, attention_scaling = base_rope_init_fn(cfg, dev)
+
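+            # recompute inv_freq with head_dim doubled and keep every other entry, so the half-width
+            # rotary dims still span the same frequency range as a full-width head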
+ cfg2 = copy.deepcopy(cfg)
+ head_dim = getattr(cfg2, "head_dim", None)
+ if head_dim is None:
+ head_dim = cfg2.hidden_size // cfg2.num_attention_heads
+ setattr(cfg2, "head_dim", head_dim)
+ cfg2.head_dim = int(head_dim) * 2
+
+ inv_freq_full, _ = base_rope_init_fn(cfg2, dev)
+ inv_freq = inv_freq_full[::2]
+
+ return inv_freq, attention_scaling
+
+ self.rope_init_fn = _rope_init_fn_keep_freq_range
+
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ @torch.no_grad()
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
+ def forward(self, x, position_ids):
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos() * self.attention_scaling
+ sin = emb.sin() * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+class Qwen3Attention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config: Qwen3Config, layer_idx: int):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+ self.scaling = self.head_dim**-0.5
+ self.attention_dropout = config.attention_dropout
+ self.is_causal = True
+
+ self.q_proj = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.q_proj_mot_gen = nn.Linear(
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
+ )
+
+ self.k_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.k_proj_mot_gen = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+
+ self.v_proj = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.v_proj_mot_gen = nn.Linear(
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
+ )
+ self.o_proj = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+ self.o_proj_mot_gen = nn.Linear(
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
+ )
+
+ self.q_norm = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
+ self.q_norm_mot_gen = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps)
+ self.q_norm_hw = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps)
+ self.q_norm_hw_mot_gen = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps)
+
+ self.k_norm = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
+ self.k_norm_mot_gen = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps)
+ self.k_norm_hw = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
+ self.k_norm_hw_mot_gen = Qwen3RMSNorm(self.head_dim // 2, eps=config.rms_norm_eps)
+
+ self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None
+
+ t_config = copy.deepcopy(config)
+ t_config.head_dim = config.head_dim // 2
+ self.rotary_emb = Qwen3RotaryEmbedding(config=t_config)
+
+ hw_config = copy.deepcopy(config)
+ hw_config.head_dim = config.head_dim // 4
+ hw_config.rope_theta = config.rope_theta_hw
+ hw_config.max_position_embeddings = config.max_position_embeddings_hw
+ self.rotary_emb_hw = Qwen3RotaryEmbedding(config=hw_config)
+
+ def forward_und(
+ self,
+ hidden_states: torch.Tensor,
+ indexes: Optional[torch.LongTensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ assert self.config._attn_implementation == "eager"
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
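+        # split each head into a temporal half and a spatial half (further split into H and W quarters),
+        # so t / h / w positions each get their own 1D RoPE (3D rotary position embedding)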
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
+ query_states_t, query_states_hw = query_states.chunk(2, dim=-1)
+ query_states_t = self.q_norm(query_states_t).transpose(1, 2)
+ query_states_hw = self.q_norm_hw(query_states_hw).transpose(1, 2)
+ query_states_h, query_states_w = query_states_hw.chunk(2, dim=-1)
+
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
+ key_states_t, key_states_hw = key_states.chunk(2, dim=-1)
+ key_states_t = self.k_norm(key_states_t).transpose(1, 2)
+ key_states_hw = self.k_norm_hw(key_states_hw).transpose(1, 2)
+ key_states_h, key_states_w = key_states_hw.chunk(2, dim=-1)
+
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+ cos_t, sin_t = self.rotary_emb(hidden_states, indexes[0].unsqueeze(0))
+ query_states_t, key_states_t = apply_rotary_pos_emb(query_states_t, key_states_t, cos_t, sin_t)
+
+ cos_h, sin_h = self.rotary_emb_hw(hidden_states, indexes[1].unsqueeze(0))
+ query_states_h, key_states_h = apply_rotary_pos_emb(query_states_h, key_states_h, cos_h, sin_h)
+
+ cos_w, sin_w = self.rotary_emb_hw(hidden_states, indexes[2].unsqueeze(0))
+ query_states_w, key_states_w = apply_rotary_pos_emb(query_states_w, key_states_w, cos_w, sin_w)
+
+ query_states = torch.cat([query_states_t, query_states_h, query_states_w], dim=-1)
+ key_states = torch.cat([key_states_t, key_states_h, key_states_w], dim=-1)
+
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ # key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+ update_cache = kwargs.get("update_cache", True)
+ if update_cache:
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs=None)
+ else:
+ # only use the past key values but do not append the current one
+ layer = past_key_values.layers[self.layer_idx]
+ past_k, past_v = layer.keys, layer.values
+
+ if past_k is not None:
+ key_states = torch.cat([past_k, key_states], dim=2) # concat on seq_len
+ value_states = torch.cat([past_v, value_states], dim=2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window, # diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights
+
+ def forward_gen(
+ self,
+ hidden_states: torch.Tensor,
+ indexes: Optional[torch.LongTensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ # -----------------------------
+ # Build q / k / v for current tokens
+ # Internal layout before flash:
+ # q/k/v: [B, H, S, D]
+ # Flash layout:
+ # q/k/v: [B, S, H, D]
+ # -----------------------------
+ query_states = self.q_proj_mot_gen(hidden_states).view(hidden_shape)
+ query_states_t, query_states_hw = query_states.chunk(2, dim=-1)
+ query_states_t = self.q_norm_mot_gen(query_states_t).transpose(1, 2) # [B,H,S,D/2]
+ query_states_hw = self.q_norm_hw_mot_gen(query_states_hw).transpose(1, 2)
+ query_states_h, query_states_w = query_states_hw.chunk(2, dim=-1)
+
+ key_states = self.k_proj_mot_gen(hidden_states).view(hidden_shape)
+ key_states_t, key_states_hw = key_states.chunk(2, dim=-1)
+ key_states_t = self.k_norm_mot_gen(key_states_t).transpose(1, 2) # [B,H,S,D/2]
+ key_states_hw = self.k_norm_hw_mot_gen(key_states_hw).transpose(1, 2)
+ key_states_h, key_states_w = key_states_hw.chunk(2, dim=-1)
+
+ value_states = self.v_proj_mot_gen(hidden_states).view(hidden_shape).transpose(1, 2) # [B,H,S,D]
+
+ # RoPE
+ cos_t, sin_t = self.rotary_emb(hidden_states, indexes[0].unsqueeze(0))
+ query_states_t, key_states_t = apply_rotary_pos_emb(query_states_t, key_states_t, cos_t, sin_t)
+
+ cos_h, sin_h = self.rotary_emb_hw(hidden_states, indexes[1].unsqueeze(0))
+ query_states_h, key_states_h = apply_rotary_pos_emb(query_states_h, key_states_h, cos_h, sin_h)
+
+ cos_w, sin_w = self.rotary_emb_hw(hidden_states, indexes[2].unsqueeze(0))
+ query_states_w, key_states_w = apply_rotary_pos_emb(query_states_w, key_states_w, cos_w, sin_w)
+
+ # concat along head_dim
+ # query/key current layout: [B, H, S, D]
+ query_states = torch.cat([query_states_t, query_states_h, query_states_w], dim=-1)
+ key_states = torch.cat([key_states_t, key_states_h, key_states_w], dim=-1)
+
+ update_cache = kwargs.get("update_cache", True)
+
+ # ------------------------------------------------------------------
+ # Flash path:
+ # Only use when there is no explicit dense mask.
+ # This is exactly the t2i denoising use case:
+ # current image tokens attend to [prefix + current image tokens]
+ # fully bidirectional inside current block => causal=False
+ # ------------------------------------------------------------------
+ if attention_mask is None:
+ # Convert current q/k/v to flash layout [B, S, H, D]
+ q = query_states.transpose(1, 2).contiguous()
+ k_cur = key_states.transpose(1, 2).contiguous()
+ v_cur = value_states.transpose(1, 2).contiguous()
+
+ if past_key_values is not None:
+ if update_cache:
+ # Rare path, keep compatibility.
+ # past_key_values.update expects [B,H,S,D]
+ key_states, value_states = past_key_values.update(
+ key_states, value_states, self.layer_idx, cache_kwargs=None
+ )
+ k = key_states.transpose(1, 2).contiguous()
+ v = value_states.transpose(1, 2).contiguous()
+ else:
+ # Optimized path:
+ # use preallocated flash_k_cache / flash_v_cache
+ layer = past_key_values.layers[self.layer_idx]
+
+ if (
+ hasattr(layer, "flash_k_cache")
+ and layer.flash_k_cache is not None
+ and hasattr(layer, "flash_v_cache")
+ and layer.flash_v_cache is not None
+ ):
+ prefix_len = layer.flash_prefix_len
+ cur_len = k_cur.shape[1]
+
+ # overwrite current segment in-place
+ layer.flash_k_cache[:, prefix_len:prefix_len + cur_len].copy_(k_cur)
+ layer.flash_v_cache[:, prefix_len:prefix_len + cur_len].copy_(v_cur)
+
+ k = layer.flash_k_cache[:, :prefix_len + cur_len]
+ v = layer.flash_v_cache[:, :prefix_len + cur_len]
+ else:
+ # fallback if user forgot to prepare flash cache
+ layer = past_key_values.layers[self.layer_idx]
+ past_k, past_v = layer.keys, layer.values
+
+ if past_k is not None:
+ past_k = past_k.transpose(1, 2).contiguous()
+ past_v = past_v.transpose(1, 2).contiguous()
+ k = torch.cat([past_k, k_cur], dim=1)
+ v = torch.cat([past_v, v_cur], dim=1)
+ else:
+ k = k_cur
+ v = v_cur
+ else:
+ k = k_cur
+ v = v_cur
+
+ # sanity checks
+ assert q.ndim == 4 and k.ndim == 4 and v.ndim == 4
+ assert q.shape[0] == k.shape[0] == v.shape[0], (q.shape, k.shape, v.shape)
+ assert k.shape[1] == v.shape[1], (k.shape, v.shape)
+ assert k.shape[2] == v.shape[2], (k.shape, v.shape)
+ assert q.shape[3] == k.shape[3] == v.shape[3], (q.shape, k.shape, v.shape)
+
+ attn_output = flash_attn_func(
+ q,
+ k,
+ v,
+ dropout_p=0.0 if not self.training else self.attention_dropout,
+ softmax_scale=self.scaling,
+ causal=False,
+ ) # [B, S_q, H_q, D]
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj_mot_gen(attn_output)
+ return attn_output, None
+
+ # ------------------------------------------------------------------
+ # Original eager fallback path
+ # ------------------------------------------------------------------
+ if past_key_values is not None:
+ if update_cache:
+ key_states, value_states = past_key_values.update(
+ key_states, value_states, self.layer_idx, cache_kwargs=None
+ )
+ else:
+ layer = past_key_values.layers[self.layer_idx]
+ past_k, past_v = layer.keys, layer.values
+ if past_k is not None:
+ key_states = torch.cat([past_k, key_states], dim=2)
+ value_states = torch.cat([past_v, value_states], dim=2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window,
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+ attn_output = self.o_proj_mot_gen(attn_output)
+ return attn_output, attn_weights
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ image_gen_indicators: torch.Tensor,
+ exist_non_image_gen_tokens: bool,
+ exist_image_gen_tokens: bool,
+ indexes: Optional[torch.LongTensor],
+ attention_mask: Optional[torch.Tensor],
+ past_key_values: Optional[Cache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[FlashAttentionKwargs],
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+ if exist_non_image_gen_tokens and not exist_image_gen_tokens:
+ return self.forward_und(hidden_states, indexes, attention_mask, past_key_values, cache_position, **kwargs)
+ if not exist_non_image_gen_tokens and exist_image_gen_tokens:
+ return self.forward_gen(hidden_states, indexes, attention_mask, past_key_values, cache_position, **kwargs)
+
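+ # Mixed und/gen batch: fall back to eager attention and route each token through either
+ # the understanding (q/k/v_proj, *_norm) or the generation (*_mot_gen) projections/norms.
+ # NOTE: the per-token routing below assumes the understanding branch has q_norm / k_norm
+ # counterparts to q_norm_mot_gen / k_norm_mot_gen (as in the base Qwen3 attention).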
+ assert self.config._attn_implementation == "eager"
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = hidden_states.new_zeros((*input_shape, self.config.num_attention_heads * self.head_dim))
+ if exist_non_image_gen_tokens:
+ query_states[~image_gen_indicators] = self.q_proj(hidden_states[~image_gen_indicators])
+ if exist_image_gen_tokens:
+ query_states[image_gen_indicators] = self.q_proj_mot_gen(hidden_states[image_gen_indicators])
+ query_states = query_states.view(hidden_shape)
+ query_states_t, query_states_hw = query_states.chunk(2, dim=-1)
+
+ _query_states_t = query_states_t.new_zeros(query_states_t.shape)
+ if exist_non_image_gen_tokens:
+ _query_states_t[~image_gen_indicators] = self.q_norm(query_states_t[~image_gen_indicators])
+ if exist_image_gen_tokens:
+ _query_states_t[image_gen_indicators] = self.q_norm_mot_gen(query_states_t[image_gen_indicators])
+ query_states_t = _query_states_t.transpose(1, 2)
+
+ _query_states_hw = query_states_hw.new_zeros(query_states_hw.shape)
+ if exist_non_image_gen_tokens:
+ _query_states_hw[~image_gen_indicators] = self.q_norm_hw(query_states_hw[~image_gen_indicators])
+ if exist_image_gen_tokens:
+ _query_states_hw[image_gen_indicators] = self.q_norm_hw_mot_gen(query_states_hw[image_gen_indicators])
+ query_states_hw = _query_states_hw.transpose(1, 2)
+ query_states_h, query_states_w = query_states_hw.chunk(2, dim=-1)
+
+ key_states = hidden_states.new_zeros((*input_shape, self.config.num_key_value_heads * self.head_dim))
+ if exist_non_image_gen_tokens:
+ key_states[~image_gen_indicators] = self.k_proj(hidden_states[~image_gen_indicators])
+ if exist_image_gen_tokens:
+ key_states[image_gen_indicators] = self.k_proj_mot_gen(hidden_states[image_gen_indicators])
+ key_states = key_states.view(hidden_shape)
+ key_states_t, key_states_hw = key_states.chunk(2, dim=-1)
+
+ _key_states_t = key_states_t.new_zeros(key_states_t.shape)
+ if exist_non_image_gen_tokens:
+ _key_states_t[~image_gen_indicators] = self.k_norm(key_states_t[~image_gen_indicators])
+ if exist_image_gen_tokens:
+ _key_states_t[image_gen_indicators] = self.k_norm_mot_gen(key_states_t[image_gen_indicators])
+ key_states_t = _key_states_t.transpose(1, 2)
+
+ _key_states_hw = key_states_hw.new_zeros(key_states_hw.shape)
+ if exist_non_image_gen_tokens:
+ _key_states_hw[~image_gen_indicators] = self.k_norm_hw(key_states_hw[~image_gen_indicators])
+ if exist_image_gen_tokens:
+ _key_states_hw[image_gen_indicators] = self.k_norm_hw_mot_gen(key_states_hw[image_gen_indicators])
+ key_states_hw = _key_states_hw.transpose(1, 2)
+ key_states_h, key_states_w = key_states_hw.chunk(2, dim=-1)
+
+ value_states = hidden_states.new_zeros((*input_shape, self.config.num_key_value_heads*self.head_dim))
+ if exist_non_image_gen_tokens:
+ value_states[~image_gen_indicators] = self.v_proj(hidden_states[~image_gen_indicators])
+ if exist_image_gen_tokens:
+ value_states[image_gen_indicators] = self.v_proj_mot_gen(hidden_states[image_gen_indicators])
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
+
+ cos_t, sin_t = self.rotary_emb(hidden_states, indexes[0].unsqueeze(0))
+ query_states_t, key_states_t = apply_rotary_pos_emb(query_states_t, key_states_t, cos_t, sin_t)
+
+ cos_h, sin_h = self.rotary_emb_hw(hidden_states, indexes[1].unsqueeze(0))
+ query_states_h, key_states_h = apply_rotary_pos_emb(query_states_h, key_states_h, cos_h, sin_h)
+
+ cos_w, sin_w = self.rotary_emb_hw(hidden_states, indexes[2].unsqueeze(0))
+ query_states_w, key_states_w = apply_rotary_pos_emb(query_states_w, key_states_w, cos_w, sin_w)
+
+ query_states = torch.cat([query_states_t, query_states_h, query_states_w], dim=-1)
+ key_states = torch.cat([key_states_t, key_states_h, key_states_w], dim=-1)
+
+
+ if past_key_values is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ # key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+ update_cache = kwargs.get("update_cache", True)
+ if update_cache:
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs=None)
+ else:
+ # only use the past key values but do not append the current one
+ layer = past_key_values.layers[self.layer_idx]
+ past_k, past_v = layer.keys, layer.values
+
+ if past_k is not None:
+ key_states = torch.cat([past_k, key_states], dim=2) # concat on seq_len
+ value_states = torch.cat([past_v, value_states], dim=2)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.attention_dropout,
+ scaling=self.scaling,
+ sliding_window=self.sliding_window, # diff with Llama
+ **kwargs,
+ )
+
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+
+ _attn_output = attn_output.new_zeros((*input_shape, self.config.hidden_size))
+ if exist_non_image_gen_tokens:
+ _attn_output[~image_gen_indicators] = self.o_proj(attn_output[~image_gen_indicators])
+ if exist_image_gen_tokens:
+ _attn_output[image_gen_indicators] = self.o_proj_mot_gen(attn_output[image_gen_indicators])
+
+ attn_output = _attn_output
+ return attn_output, attn_weights
+
+
+class Qwen3DecoderLayer(GradientCheckpointingLayer):
+ def __init__(self, config: Qwen3Config, layer_idx: int):
+ super().__init__()
+ self.hidden_size = config.hidden_size
+
+ self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
+
+ self.mlp = Qwen3MLP(config)
+ self.mlp_mot_gen = Qwen3MLP(config)
+ self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.input_layernorm_mot_gen = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm_mot_gen = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.attention_type = config.layer_types[layer_idx]
+
+ def forward_und(
+ self,
+ hidden_states: torch.Tensor,
+ image_gen_indicators: torch.Tensor,
+ exist_non_image_gen_tokens: bool,
+ exist_image_gen_tokens: bool,
+ indexes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ image_gen_indicators=image_gen_indicators,
+ exist_non_image_gen_tokens=exist_non_image_gen_tokens,
+ exist_image_gen_tokens=exist_image_gen_tokens,
+ indexes=indexes,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+ def forward_gen(
+ self,
+ hidden_states: torch.Tensor,
+ image_gen_indicators: torch.Tensor,
+ exist_non_image_gen_tokens: bool,
+ exist_image_gen_tokens: bool,
+ indexes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.input_layernorm_mot_gen(hidden_states)
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ image_gen_indicators=image_gen_indicators,
+ exist_non_image_gen_tokens=exist_non_image_gen_tokens,
+ exist_image_gen_tokens=exist_image_gen_tokens,
+ indexes=indexes,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm_mot_gen(hidden_states)
+ hidden_states = self.mlp_mot_gen(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+ @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ image_gen_indicators: torch.Tensor,
+ exist_non_image_gen_tokens: bool,
+ exist_image_gen_tokens: bool,
+ indexes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> torch.Tensor:
+ if exist_non_image_gen_tokens and not exist_image_gen_tokens:
+ return self.forward_und(hidden_states, image_gen_indicators, exist_non_image_gen_tokens, exist_image_gen_tokens, indexes, attention_mask, position_ids, past_key_values, use_cache, cache_position, **kwargs)
+ if not exist_non_image_gen_tokens and exist_image_gen_tokens:
+ return self.forward_gen(hidden_states, image_gen_indicators, exist_non_image_gen_tokens, exist_image_gen_tokens, indexes, attention_mask, position_ids, past_key_values, use_cache, cache_position, **kwargs)
+
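+ # Mixed batch: route tokens flagged by image_gen_indicators through the *_mot_gen
+ # layernorms/MLP and the remaining tokens through the standard understanding branch.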
+ residual = hidden_states
+
+ _hidden_states = hidden_states.new_zeros(hidden_states.shape)
+ if exist_non_image_gen_tokens:
+ _hidden_states[~image_gen_indicators] = self.input_layernorm(hidden_states[~image_gen_indicators])
+ if exist_image_gen_tokens:
+ _hidden_states[image_gen_indicators] = self.input_layernorm_mot_gen(hidden_states[image_gen_indicators])
+ hidden_states = _hidden_states
+
+ # Self Attention
+ hidden_states, _ = self.self_attn(
+ hidden_states=hidden_states,
+ image_gen_indicators=image_gen_indicators,
+ exist_non_image_gen_tokens=exist_non_image_gen_tokens,
+ exist_image_gen_tokens=exist_image_gen_tokens,
+ indexes=indexes,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+
+ _hidden_states = hidden_states.new_zeros(hidden_states.shape)
+ if exist_non_image_gen_tokens:
+ _hidden_states[~image_gen_indicators] = self.mlp(self.post_attention_layernorm(hidden_states[~image_gen_indicators]))
+
+ if exist_image_gen_tokens:
+ _hidden_states[image_gen_indicators] = self.mlp_mot_gen(self.post_attention_layernorm_mot_gen(hidden_states[image_gen_indicators]))
+
+ hidden_states = _hidden_states
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+@auto_docstring
+class Qwen3PreTrainedModel(PreTrainedModel):
+ config: Qwen3Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Qwen3DecoderLayer"]
+ _skip_keys_device_placement = ["past_key_values"]
+ _supports_flash_attn = True
+ _supports_sdpa = True
+ _supports_flex_attn = True
+
+ _can_compile_fullgraph = True
+ _supports_attention_backend = True
+ _can_record_outputs = {
+ "hidden_states": Qwen3DecoderLayer,
+ "attentions": Qwen3Attention,
+ }
+
+
+@auto_docstring
+class Qwen3Model(Qwen3PreTrainedModel):
+ def __init__(self, config: Qwen3Config):
+ super().__init__(config)
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ self.layers = nn.ModuleList(
+ [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.norm_mot_gen = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ self.gradient_checkpointing = False
+ self.has_sliding_layers = "sliding_attention" in self.config.layer_types
+ self.current_index = -1
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @check_model_inputs
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ image_gen_indicators: Optional[torch.Tensor] = None,
+ indexes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> BaseModelOutputWithPast:
+
+ # assert position_ids is not None
+ # assert cache_position is not None
+ # assert past_key_values is not None
+
+ if image_gen_indicators is None:
+ exist_non_image_gen_tokens = True
+ exist_image_gen_tokens = False
+ else:
+ exist_non_image_gen_tokens = (~image_gen_indicators).any()
+ exist_image_gen_tokens = image_gen_indicators.any()
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if use_cache and past_key_values is None:
+ past_key_values = DynamicCache(config=self.config)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # It may already have been prepared by e.g. `generate`
+ if not isinstance(causal_mask_mapping := attention_mask, dict):
+ # Prepare mask arguments
+ if input_ids is not None:
+ mask_kwargs = {
+ "config": self.config,
+ "input_embeds": inputs_embeds,
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "position_ids": position_ids,
+ }
+ # Create the masks
+ causal_mask_mapping = {
+ "full_attention": create_causal_mask(**mask_kwargs),
+ }
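+ # Text-only decode step: only the temporal RoPE index advances; the h/w indexes stay 0.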
+ self.current_index += 1
+ indexes = torch.LongTensor([[self.current_index], [0], [0]]).to(input_ids.device)
+ else:
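+ # Embeds-driven path (likely image-generation prefill): build a block-causal mask
+ # from the caller-provided temporal indexes.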
+ causal_mask_mapping = {
+ "full_attention": create_block_causal_mask(indexes[0]),
+ }
+ self.current_index = indexes[0].max()
+ else:
+ self.current_index = indexes[0].max()
+ # raise NotImplementedError('not isinstance(causal_mask_mapping := attention_mask, dict)')
+
+ # The sliding window alternating layers are not always activated depending on the config
+ # if self.has_sliding_layers:
+ # causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)
+
+ hidden_states = inputs_embeds
+
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+ hidden_states = decoder_layer(
+ hidden_states,
+ image_gen_indicators=image_gen_indicators,
+ exist_non_image_gen_tokens=exist_non_image_gen_tokens,
+ exist_image_gen_tokens=exist_image_gen_tokens,
+ indexes=indexes,
+ attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+ if not exist_image_gen_tokens:
+ hidden_states = self.norm(hidden_states)
+ elif not exist_non_image_gen_tokens:
+ hidden_states = self.norm_mot_gen(hidden_states)
+ else:
+ _hidden_states = hidden_states.new_zeros(hidden_states.shape)
+ _hidden_states[~image_gen_indicators] = self.norm(hidden_states[~image_gen_indicators])
+ _hidden_states[image_gen_indicators] = self.norm_mot_gen(hidden_states[image_gen_indicators])
+ hidden_states = _hidden_states
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=past_key_values if use_cache else None,
+ )
+
+
+@auto_docstring
+class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["lm_head.weight"]
+ _tp_plan = {"lm_head": "colwise_rep"}
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = Qwen3Model(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ indexes: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> CausalLMOutputWithPast:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
+
+ >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ outputs: BaseModelOutputWithPast = self.model(
+ input_ids=input_ids,
+ indexes=indexes,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class Qwen3ForSequenceClassification(GenericForSequenceClassification, Qwen3PreTrainedModel):
+ pass
+
+
+class Qwen3ForTokenClassification(GenericForTokenClassification, Qwen3PreTrainedModel):
+ pass
+
+
+class Qwen3ForQuestionAnswering(GenericForQuestionAnswering, Qwen3PreTrainedModel):
+ base_model_prefix = "transformer" # For BC, where `transformer` was used instead of `model`
+
+
+__all__ = [
+ "Qwen3ForCausalLM",
+ "Qwen3ForQuestionAnswering",
+ "Qwen3PreTrainedModel",
+ "Qwen3Model",
+ "Qwen3ForSequenceClassification",
+ "Qwen3ForTokenClassification",
+]
\ No newline at end of file
diff --git a/lightllm/server/x2i_server/past_kv_cache_client.py b/lightllm/server/x2i_server/past_kv_cache_client.py
index 4a95635d1e..b06db8da3c 100644
--- a/lightllm/server/x2i_server/past_kv_cache_client.py
+++ b/lightllm/server/x2i_server/past_kv_cache_client.py
@@ -12,17 +12,11 @@
attach_shm_kv_cache_ptr,
register_shm_ptr_to_pin,
)
+from lightllm.server.core.objs.x2i_params import PastKVCacheItem
logger = init_logger(__name__)
-@dataclass
-class PastKVCacheItem:
- req_id: int
- token_len: int
- page_indexes: List[int]
-
-
class PastKVCacheClient(object):
"""
This class is responsible for passing kv cache between generation server and model server,
@@ -49,7 +43,7 @@ def __init__(self, only_create_meta_data: bool, init_shm_data: bool):
self.attach_shm_handle.wait()
return
- def allocate_pages(self, req_id: int, need_tokens: int) -> List[int]:
+ def allocate_pages(self, req_id: int, need_tokens: int, img_tokens: int = 0, img_len: int = 0) -> List[int]:
need_pages = (need_tokens + self.token_page_size - 1) // self.token_page_size
if need_pages > self.page_num:
logger.error(
@@ -64,7 +58,7 @@ def allocate_pages(self, req_id: int, need_tokens: int) -> List[int]:
page_indexes, self.free_pages = self.free_pages[:need_pages], self.free_pages[need_pages:]
self.allocated_pages_dict[req_id] = PastKVCacheItem(
- req_id=req_id, token_len=need_tokens, page_indexes=page_indexes)
+ req_id=req_id, token_len=need_tokens, page_indexes=page_indexes, img_tokens=img_tokens, img_len=img_len)
return page_indexes
@@ -78,7 +72,7 @@ def free_pages_by_req_id(self, req_id: int):
def get_pages_by_req_id(self, req_id: int) -> Optional[List[int]]:
with self.lock:
item = self.allocated_pages_dict.get(req_id, None)
- return item.page_indexes if item is not None else None
+ return item
def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optional[torch.Tensor]:
if page_indexes is None:
From a5a08447345203b7e9253f8ae03123799bd50abf Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Fri, 27 Mar 2026 16:32:18 +0000
Subject: [PATCH 05/41] interleave api.
---
lightllm/models/neo_chat_moe/model.py | 16 +-
lightllm/server/api_http.py | 12 +-
lightllm/server/api_models.py | 100 ++++++-
lightllm/server/api_openai.py | 258 ++++++++++++++++--
.../core/objs/token_chunck_hash_list.py | 3 +-
lightllm/server/httpserver/manager.py | 5 +-
lightllm/server/multimodal_params.py | 3 +
lightllm/server/x2i_server/manager.py | 2 +-
8 files changed, 358 insertions(+), 41 deletions(-)
diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py
index edabc06075..ec8772cc72 100644
--- a/lightllm/models/neo_chat_moe/model.py
+++ b/lightllm/models/neo_chat_moe/model.py
@@ -42,6 +42,7 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
self.image_start_id = tokenizer.convert_tokens_to_ids(self.image_start_tag)
self.image_end_tag = IMG_END_TOKEN
self.image_end_id = tokenizer.convert_tokens_to_ids(self.image_end_tag)
+ self.image_tag = IMG_TOKEN
self.conversation_module = self.load_conversion_module(tokenizer.name_or_path)
self.template = model_cfg.get("template", "neo1_0")
@@ -146,16 +147,19 @@ def fix_prompt(self, prompt: str, img_len: int):
def get_query_for_it2i(self, prompt: str):
image_len = prompt.count(IMG_TOKEN)
- query_condition = self._build_t2i_query(prompt, thinking_content="\n\n\n\n")
+ # query_condition = self._build_t2i_query(prompt, thinking_content="\n\n\n\n")
+ query_condition = prompt + IMG_START_TOKEN if not prompt.endswith(IMG_START_TOKEN) else prompt
query_text_uncondition = self._build_t2i_query(IMG_TOKEN * image_len)
question_img_uncondition = self._build_t2i_query("")
return query_condition, query_text_uncondition, question_img_uncondition
- def get_query_for_t2i(self, prompt):
- query_condition = self._build_t2i_query(
- f"Please generate an image based on the following description: {prompt}",
- thinking_content="\n\n\n\n")
- query_uncondition = self._build_t2i_query(f"")
+ def get_query_for_t2i(self, prompt: str):
+ # the prompt already has the chat/query template applied by the caller
+ query_condition = prompt + IMG_START_TOKEN if not prompt.endswith(IMG_START_TOKEN) else prompt
+ # query_condition = self._build_t2i_query(
+ # f"Please generate an image based on the following description: {prompt}",
+ # thinking_content="\n\n\n\n")
+ query_uncondition = self._build_t2i_query("")
return query_condition, query_uncondition
diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 4c366fb770..65571bf0fc 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -51,12 +51,13 @@
from lightllm.utils.envs_utils import get_unique_server_name
from dataclasses import dataclass
-from .api_openai import chat_completions_impl, completions_impl
+from .api_openai import chat_completions_impl, completions_impl, chat_completions_impl_v2
from .api_models import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
+ ChatCompletionRequestV2,
)
from .build_prompt import build_prompt, init_tokenizer
@@ -274,6 +275,15 @@ async def generate_image(request: Request) -> Response:
except Exception as e:
return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))
+@app.post("/v2/chat/completions", response_model=ChatCompletionResponse)
+async def completions_v2(request: ChatCompletionRequestV2, raw_request: Request) -> Response:
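+ # /v2 extends /v1 chat completions with `modalities` and `image_config`; requests without
+ # "image" in modalities fall through to the regular chat completion path. A sketch of a
+ # request body (fields per ChatCompletionRequestV2): {"model": "...", "messages": [...],
+ # "stream": true, "modalities": ["text", "image"],
+ # "image_config": {"aspect_ratio": "16:9", "image_size": "1K"}}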
+ if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
+ return create_error_response(
+ HTTPStatus.EXPECTATION_FAILED, "service in pd mode does not accept requests from the http interface"
+ )
+
+ resp = await chat_completions_impl_v2(request, raw_request)
+ return resp
@app.get("/tokens")
@app.post("/tokens")
diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py
index 3d9a6bc8ed..5725799338 100644
--- a/lightllm/server/api_models.py
+++ b/lightllm/server/api_models.py
@@ -3,7 +3,7 @@
import uuid
from pydantic import BaseModel, Field, field_validator, model_validator
-from typing import Any, Dict, List, Optional, Union, Literal, ClassVar
+from typing import Any, Dict, List, Optional, Union, Literal, ClassVar, TypeAlias
from transformers import GenerationConfig
@@ -275,7 +275,7 @@ class UsageInfo(BaseModel):
class ChatMessage(BaseModel):
role: Optional[str] = None
- content: Optional[str] = None
+ content: Optional[Union[str, List[MessageContent]]] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
@@ -301,7 +301,7 @@ def ensure_id_is_str(cls, v):
class DeltaMessage(BaseModel):
role: Optional[str] = None
- content: Optional[str] = None
+ content: Optional[Union[str, List[MessageContent]]] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
reasoning_content: Optional[str] = None
@@ -370,3 +370,97 @@ class CompletionStreamResponse(BaseModel):
@field_validator("id", mode="before")
def ensure_id_is_str(cls, v):
return str(v)
+
+
+# Supported values
+AspectRatio: TypeAlias = Literal[
+ "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4",
+ "9:16", "16:9", "21:9",
+]
+
+ImageSize: TypeAlias = Literal["0.5K", "1K", "2K", "4K"]
+
+Modality: TypeAlias = Literal["text", "image", "audio"]
+
+ImageType: TypeAlias = Literal["png", "jpeg", "webp"]
+
+class ImageConfig(BaseModel):
+ aspect_ratio: AspectRatio = "1:1"
+ image_size: ImageSize = "1K"
+ image_type: ImageType = "jpeg"
+
+ # Mapping to actual resolutions (base resolution for 1K)
+ _aspect_ratio_to_resolution: ClassVar[dict] = {
+ "1:1": (1024, 1024),
+ "2:3": (832, 1248),
+ "3:2": (1248, 832),
+ "3:4": (864, 1184),
+ "4:3": (1184, 864),
+ "4:5": (896, 1152),
+ "5:4": (1152, 896),
+ "9:16": (768, 1344),
+ "16:9": (1344, 768),
+ "21:9": (1536, 672),
+ }
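+ # The base resolutions above are all roughly 1 megapixel (the "1K" tier); the multiplier
+ # below scales width and height, so "2K" yields about 4x the pixel count of "1K".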
+
+ _size_multiplier: ClassVar[dict] = {
+ "0.5K": 0.5,
+ "1K": 1.0,
+ "2K": 2.0,
+ "4K": 4.0,
+ }
+
+ @field_validator("aspect_ratio")
+ @classmethod
+ def validate_aspect_ratio(cls, v):
+ if v not in cls._aspect_ratio_to_resolution:
+ raise ValueError(f"Unsupported aspect ratio: {v}")
+ return v
+
+ @field_validator("image_size")
+ @classmethod
+ def validate_image_size(cls, v):
+ if v not in cls._size_multiplier:
+ raise ValueError(f"Unsupported image size: {v}")
+ return v
+
+ @field_validator("image_type")
+ @classmethod
+ def validate_image_type(cls, v):
+ if v not in ['jpeg', 'png', 'webp']:
+ raise ValueError(f"unsupported image type: {v}")
+ return v
+
+ def get_resolution(self):
+ """Return scaled resolution (width, height)"""
+ base = self._aspect_ratio_to_resolution[self.aspect_ratio]
+ if base is None:
+ return None # extended ratios don't have fixed base
+
+ scale = self._size_multiplier[self.image_size]
+ w, h = base
+ return int(w * scale), int(h * scale)
+
+
+class ChatCompletionRequestV2(ChatCompletionRequest):
+ modalities: List[Modality] = ["text"]
+ image_config: Optional[ImageConfig] = None
+
+ @field_validator("modalities")
+ @classmethod
+ def validate_modalities(cls, v):
+ if "text" not in v:
+ raise ValueError("modalities must include 'text'")
+ if len(v) != len(set(v)):
+ raise ValueError("modalities must be unique")
+ return v
+
+ @model_validator(mode="after")
+ def validate_image_config(self):
+ if "image" in self.modalities:
+ if self.image_config is None:
+ self.image_config = ImageConfig()
+ else:
+ if self.image_config is not None:
+ raise ValueError("image_config provided but 'image' not in modalities")
+ return self
\ No newline at end of file
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index 598bd8f1f2..e40bfaca7d 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -25,6 +25,7 @@
from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response, StreamingResponse, JSONResponse
from lightllm.server.core.objs.sampling_params import SamplingParams
+from lightllm.server.core.objs.x2i_params import X2IParams
from .multimodal_params import MultimodalParams
from .httpserver.manager import HttpServerManager
from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
@@ -53,6 +54,9 @@
DeltaMessage,
ChatCompletionStreamResponse,
ChatCompletionStreamResponseChoice,
+ ChatCompletionRequestV2,
+ MessageContent,
+ ImageURL,
)
logger = init_logger(__name__)
@@ -159,38 +163,21 @@ def _process_tools_stream(index: int, delta: str, parser_dict: Dict, request: Ch
normal_text, calls = parser.parse_stream_chunk(delta)
return normal_text, calls
-
-async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Request) -> Response:
- from .api_http import g_objs
-
- if request.logit_bias is not None:
- return create_error_response(
- HTTPStatus.BAD_REQUEST,
- "The logit_bias parameter is not currently supported",
- )
-
- if request.function_call != "none":
- return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported")
-
- created_time = int(time.time())
-
- multimodal_params_dict = {"images": [], "audios": []}
+def _get_images_and_audios(request: ChatCompletionRequest):
+ images, audios = [], []
for message in request.messages:
if isinstance(message.content, list):
- texts = []
for content in message.content:
- if content.type == "text" and content.text:
- texts.append(content.text)
- elif content.type == "image_url" and content.image_url is not None:
+ if content.type == "image_url" and content.image_url is not None:
img = content.image_url.url
if img.startswith("http://") or img.startswith("https://"):
- multimodal_params_dict["images"].append({"type": "url", "data": img})
+ images.append({"type": "url", "data": img})
elif img.startswith("data:image"):
# "data:image/jpeg;base64,{base64_image}"
data_str = img.split(";", 1)[1]
if data_str.startswith("base64,"):
data = data_str[7:]
- multimodal_params_dict["images"].append({"type": "base64", "data": data})
+ images.append({"type": "base64", "data": data})
else:
raise ValueError("Unrecognized image input.")
else:
@@ -200,18 +187,19 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
elif content.type == "audio_url" and content.audio_url is not None:
audio = content.audio_url.url
if audio.startswith("http://") or audio.startswith("https://"):
- multimodal_params_dict["audios"].append({"type": "url", "data": audio})
+ audios.append({"type": "url", "data": audio})
elif audio.startswith("data:audio"):
data_str = audio.split(";", 1)[1]
if data_str.startswith("base64,"):
data = data_str[7:]
- multimodal_params_dict["audios"].append({"type": "base64", "data": data})
+ audios.append({"type": "base64", "data": data})
else:
raise ValueError("Unrecognized audio input.")
else:
raise ValueError("Unrecognized audio input. Supports local path, http url, base64.")
+ return images, audios
- tools = None
+def _get_tools(request: ChatCompletionRequest):
+ tools = None
if request.tools and request.tool_choice != "none":
# request.skip_special_tokens = False
if not isinstance(request.tool_choice, str):
@@ -223,6 +211,26 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
else:
tools = [item.function.model_dump() for item in request.tools]
+ return tools
+
+async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Request) -> Response:
+ from .api_http import g_objs
+
+ if request.logit_bias is not None:
+ return create_error_response(
+ HTTPStatus.BAD_REQUEST,
+ "The logit_bias parameter is not currently supported",
+ )
+
+ if request.function_call != "none":
+ return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported")
+
+ created_time = int(time.time())
+
+ images, audios = _get_images_and_audios(request)
+ multimodal_params_dict = {"images": images, "audios": audios}
+
+ tools = _get_tools(request)
+
prompt = await build_prompt(request, tools)
sampling_params_dict = {
"do_sample": request.do_sample,
@@ -526,6 +534,206 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
+async def _get_text_generator_input(request: ChatCompletionRequest):
+ from .api_http import g_objs
+
+ images, audios = _get_images_and_audios(request)
+ multimodal_params_dict = {"images": images, "audios": audios}
+
+ tools = _get_tools(request)
+
+ prompt = await build_prompt(request, tools)
+
+ sampling_params_dict = {
+ "do_sample": request.do_sample,
+ "presence_penalty": request.presence_penalty,
+ "frequency_penalty": request.frequency_penalty,
+ "repetition_penalty": request.repetition_penalty,
+ "temperature": request.temperature,
+ "top_p": request.top_p,
+ "top_k": request.top_k,
+ "ignore_eos": request.ignore_eos,
+ "max_new_tokens": request.max_tokens,
+ "stop_sequences": request.stop,
+ "n": request.n,
+ "best_of": request.n,
+ "add_special_tokens": False,
+ }
+ # Structured output handling
+ if request.response_format:
+ if request.response_format.type == "json_schema":
+ obj = request.response_format.json_schema
+ if obj:
+ # guided_json takes str instead of dict obj
+ sampling_params_dict["guided_json"] = json.dumps(obj.json_schema)
+ elif request.response_format.type == "json_object":
+ sampling_params_dict["guided_grammar"] = "json"
+
+ sampling_params = SamplingParams()
+ sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sampling_params_dict)
+
+ sampling_params.verify()
+ multimodal_params = MultimodalParams(**multimodal_params_dict)
+
+ logger.info(f"call text generator with prompt: {prompt} and {sampling_params_dict}")
+
+ return prompt, sampling_params, multimodal_params
+
+
+async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request: Request) -> Response:
+ from .api_http import g_objs
+
+ if request.logit_bias is not None:
+ return create_error_response(
+ HTTPStatus.BAD_REQUEST,
+ "The logit_bias parameter is not currently supported",
+ )
+
+ if request.function_call != "none":
+ return create_error_response(HTTPStatus.BAD_REQUEST, "The function call feature is not supported")
+
+ if request.chat_template_kwargs is None:
+ request.chat_template_kwargs = {}
+ request.chat_template_kwargs.update({"enable_thinking": False})
+
+ chat_request: ChatCompletionRequest = ChatCompletionRequest(**request.model_dump())
+
+ logger.info(f"{type(chat_request)} and {type(request)} and {request.model_dump_json()} and {chat_request.model_dump_json()}")
+ if "image" not in request.modalities:
+ return await chat_completions_impl(chat_request, raw_request)
+
+ if not request.stream or request.n != 1:
+ return create_error_response(HTTPStatus.BAD_REQUEST, "image only support stream api with n = 1")
+
+ image_start_tag = g_objs.httpserver_manager.tokenizer.image_start_tag
+ image_tag = g_objs.httpserver_manager.tokenizer.image_tag
+
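+ # add image_start_tag as a stop sequence so text generation pauses whenever the model
+ # decides to emit an image; the x2i image generator then produces that segment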
+ stop = chat_request.stop or []
+ if isinstance(stop, str):
+ stop = [stop]
+ stop.append(image_start_tag)
+ chat_request.stop = stop
+
+ created_time = int(time.time())
+
+
+ prompt, sampling_params, multimodal_params = await _get_text_generator_input(chat_request)
+
+ width, height = request.image_config.get_resolution()
+ x2i_params_dict = {
+ "width": width,
+ "height": height,
+ }
+
+ x2i_params = X2IParams()
+ x2i_params.init(**x2i_params_dict)
+
+ async def stream_result() -> AsyncGenerator[bytes, None]:
+ nonlocal prompt
+ from .req_id_generator import convert_sub_id_to_group_id
+ prompt_tokens = 0
+ completion_tokens = 0
+ finish_reason = None
+ group_request_id = None
+
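+ # interleave loop: stream text until the model emits image_start_tag, call the x2i
+ # image generator, splice the result back into the prompt, then resume text generation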
+ while True:
+ need_call_x2i = False
+ text_generator = g_objs.httpserver_manager.generate(
+ prompt, sampling_params, multimodal_params.clone(), request=raw_request)
+
+ reasoning_parser_dict = {}
+ output_chunk = ""
+
+ async for sub_req_id, request_output, metadata, finish_status in text_generator:
+ prompt_tokens = metadata["prompt_tokens"]
+ completion_tokens += 1
+ if group_request_id is None:
+ group_request_id = convert_sub_id_to_group_id(sub_req_id)
+
+ index = sub_req_id
+ delta = request_output
+ finish_reason = finish_status.get_finish_reason()
+ if delta == image_start_tag:
+ need_call_x2i = True
+ continue
+
+ output_chunk += delta
+
+ # Handle reasoning content
+ if get_env_start_args().reasoning_parser and request.separate_reasoning:
+ reasoning_text, delta = _process_reasoning_stream(
+ index, delta, reasoning_parser_dict, request_output, request
+ )
+ if reasoning_text:
+ choice_data = ChatCompletionStreamResponseChoice(
+ index=0,
+ delta=DeltaMessage(reasoning_content=reasoning_text),
+ finish_reason=None,
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=group_request_id,
+ created=created_time,
+ choices=[choice_data],
+ model=request.model,
+ )
+ yield f"data: {chunk.model_dump_json()}\n\n"
+
+
+ delta_message = DeltaMessage(role="assistant", content=delta)
+ if finish_status.is_finished():
+ finish_reason = finish_status.get_finish_reason()
+ stream_choice = ChatCompletionStreamResponseChoice(
+ index=0, delta=delta_message, finish_reason=finish_reason
+ )
+ stream_resp = ChatCompletionStreamResponse(
+ id=group_request_id,
+ created=created_time,
+ model=request.model,
+ choices=[stream_choice],
+ )
+ yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8")
+
+
+ if need_call_x2i:
+ prompt += output_chunk
+
+ images = await g_objs.httpserver_manager.generate_image(
+ prompt, x2i_params, multimodal_params.clone(), request=raw_request)
+
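+ # stream each generated image back to the client and also feed it into the prompt
+ # (IMG token + multimodal image) so subsequent turns can condition on it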
+ delta_message = DeltaMessage(role="assistant", content=[])
+ for image in images:
+ message_content = MessageContent(type="image_url", image_url=ImageURL(url=image))
+ delta_message.content.append(message_content)
+ prompt += image_tag
+ multimodal_params.add_image({"type": "base64", "data": image})
+
+ stream_resp = ChatCompletionStreamResponse(
+ id=group_request_id,
+ created=created_time,
+ model=request.model,
+ choices=[ChatCompletionStreamResponseChoice(index=0, delta=delta_message)],
+ )
+ yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8")
+
+ else:
+ break
+
+ if request.stream_options and request.stream_options.include_usage:
+ usage = UsageInfo(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+ usage_chunk = ChatCompletionStreamResponse(
+ id=group_request_id,
+ created=created_time,
+ choices=[],
+ model=request.model,
+ usage=usage,
+ )
+ yield f"data: {usage_chunk.model_dump_json()}\n\n"
+
+ return StreamingResponse(stream_result(), media_type="text/event-stream", background=BackgroundTasks())
+
+
async def completions_impl(request: CompletionRequest, raw_request: Request) -> Response:
from .api_http import g_objs
diff --git a/lightllm/server/core/objs/token_chunck_hash_list.py b/lightllm/server/core/objs/token_chunck_hash_list.py
index b16f264936..de43cc4cc6 100644
--- a/lightllm/server/core/objs/token_chunck_hash_list.py
+++ b/lightllm/server/core/objs/token_chunck_hash_list.py
@@ -77,8 +77,7 @@ def is_full(self):
return self.size == LIGHTLLM_TOKEN_HASH_LIST_SIZE
def fill(self, data: List[int]):
- assert self.size == 0
- assert len(data) <= LIGHTLLM_TOKEN_HASH_LIST_SIZE
+ assert len(data) <= LIGHTLLM_TOKEN_HASH_LIST_SIZE, f"data size is too large: {len(data)}"
self.items[0 : len(data)] = data
self.size = len(data)
return
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 9502d630f0..aa590aacea 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -477,7 +477,7 @@ async def generation_wrapper(prompt, sample, multimodal, request):
generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
generation_wrapper(prompt_uncondition, sample_params, multimodal_params, request)])
generation_params.update_t2i(con_gen, uncon_gen)
- # use the first reqeust id as the gen image request id
+ # use the first request id as the gen image request id
x2i_req_id = generate_req_ids[0]
generation_params.request_id = x2i_req_id
@@ -488,7 +488,6 @@ async def generation_wrapper(prompt, sample, multimodal, request):
await self.send_to_x2i.send_pyobj(generation_params, protocol=pickle.HIGHEST_PROTOCOL)
await req_status.event.wait()
-
assert req_status.response is not None
self.req_id_to_x2i_reqs.pop(x2i_req_id, None)
@@ -496,7 +495,7 @@ async def generation_wrapper(prompt, sample, multimodal, request):
return req_status.response.images
except Exception as e:
- logger.error(str(e))
+ logger.error(str(e), exc_info=e)
return []
finally:
diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py
index 999dfc6859..4a3b34e49a 100644
--- a/lightllm/server/multimodal_params.py
+++ b/lightllm/server/multimodal_params.py
@@ -170,6 +170,9 @@ async def verify_and_preload(self, request: Request):
await audio.preload(request)
return
+ def add_image(self, image: dict):
+ self.images.append(ImageItem(**image))
+
def to_dict(self):
ret = {}
ret["images"] = [i.to_dict() for i in self.images]
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index 11e53b5d67..828a5ec35e 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -105,7 +105,7 @@ async def loop_for_fwd(self):
protocol=pickle.HIGHEST_PROTOCOL)
images = []
- logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with x2i_param: {x2i_param}")
+ logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with: {x2i_param}")
if is_t2i:
images = await self.t2i_generate(past_kv_cache, past_kv_cache_text, x2i_param)
else:
From d4aaa77a94ff38f4613de08ecc800a75dcd1264e Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Fri, 27 Mar 2026 17:31:20 +0000
Subject: [PATCH 06/41] nit
---
lightllm/server/api_openai.py | 10 ++++++++--
lightllm/server/api_start.py | 2 +-
lightllm/server/multimodal_params.py | 3 ---
3 files changed, 9 insertions(+), 6 deletions(-)
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index e40bfaca7d..07920fad19 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -598,7 +598,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
chat_request: ChatCompletionRequest = ChatCompletionRequest(**request.model_dump())
- logger.info(f"{type(chat_request)} and {type(request)} and {request.model_dump_json()} and {chat_request.model_dump_json()}")
+ # logger.info(f"{type(chat_request)} and {type(request)} and {request.model_dump_json()} and {chat_request.model_dump_json()}")
if "image" not in request.modalities:
return await chat_completions_impl(chat_request, raw_request)
@@ -633,8 +633,10 @@ async def stream_result() -> AsyncGenerator[bytes, None]:
completion_tokens = 0
finish_reason = None
group_request_id = None
+ max_image_gen_num = 15 # TODO: make this configurable
- while True:
+ while max_image_gen_num > 0:
+ max_image_gen_num -= 1
need_call_x2i = False
text_generator = g_objs.httpserver_manager.generate(
prompt, sampling_params, multimodal_params.clone(), request=raw_request)
@@ -698,6 +700,10 @@ async def stream_result() -> AsyncGenerator[bytes, None]:
images = await g_objs.httpserver_manager.generate_image(
prompt, x2i_params, multimodal_params.clone(), request=raw_request)
+ if len(images) == 0:
+ logger.warning(f"No image generated by x2i: {prompt[-100:]}, exit...")
+ break
+
delta_message = DeltaMessage(role="assistant", content=[])
for image in images:
message_content = MessageContent(type="image_url", image_url=ImageURL(url=image))
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 358ae425b4..5f43f95d65 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -250,7 +250,7 @@ def normal_or_p_d_start(args):
node_world_size = args.tp // args.nnodes
can_use_ports = alloc_can_use_network_port(
- num=12+ node_world_size + args.visual_dp * (args.visual_tp + 1), used_nccl_ports=already_uesd_ports
+ num=12 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports
)
logger.info(f"alloced ports: {can_use_ports}")
(
diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py
index 4a3b34e49a..20efb1f053 100644
--- a/lightllm/server/multimodal_params.py
+++ b/lightllm/server/multimodal_params.py
@@ -187,8 +187,6 @@ def to_origin_dict(self):
ret["images"] = [i.to_origin_dict() for i in self.images]
ret["audios"] = [a.to_origin_dict() for a in self.audios]
return ret
-<<<<<<< HEAD
-=======
def free(self):
for image in self.images:
@@ -199,4 +197,3 @@ def free(self):
def clone(self):
return MultimodalParams(**self.to_origin_dict())
->>>>>>> 43a488a8 (add naive x2i backend.)
From 99deab60bf051699d2db2f1f6a8010611ae42c11 Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Fri, 27 Mar 2026 18:31:19 +0000
Subject: [PATCH 07/41] fix.
---
lightllm/server/api_start.py | 28 +++++++++----------
.../model_infer/mode_backend/base_backend.py | 1 -
lightllm/server/tokenizer.py | 3 ++
3 files changed, 17 insertions(+), 15 deletions(-)
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 5f43f95d65..5ce80bf419 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -352,20 +352,20 @@ def normal_or_p_d_start(args):
],
)
- if args.enable_multimodal_x2i:
- from .x2i_server.manager import start_x2i_process, setup_devices
- origin_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
- setup_devices(args)
- process_manager.start_submodule_processes(
- start_funcs=[
- start_x2i_process,
- ],
- start_args=[
- (args,),
- ],
- )
- if origin_devices:
- os.environ["CUDA_VISIBLE_DEVICES"] = origin_devices
+ if args.enable_multimodal_x2i:
+ from .x2i_server.manager import start_x2i_process, setup_devices
+ origin_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+ setup_devices(args)
+ process_manager.start_submodule_processes(
+ start_funcs=[
+ start_x2i_process,
+ ],
+ start_args=[
+ (args,),
+ ],
+ )
+ if origin_devices:
+ os.environ["CUDA_VISIBLE_DEVICES"] = origin_devices
if args.enable_cpu_cache:
from .multi_level_kv_cache.manager import start_multi_level_kv_cache_manager
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 12e418824b..0cc492de20 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -49,7 +49,6 @@
from .multi_level_kv_cache import MultiLevelKvCacheModule
from .past_kv_cache import PastKVCacheModule
-
class ModeBackend:
def __init__(self) -> None:
self.shm_req_manager = ShmReqManager()
diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py
index 2800bf0f6b..87904aa97d 100644
--- a/lightllm/server/tokenizer.py
+++ b/lightllm/server/tokenizer.py
@@ -32,6 +32,7 @@
from ..models.internvl.model import InternvlTokenizer
from ..models.gemma3.model import Gemma3Tokenizer
from ..models.qwen3_omni_moe_thinker.model import QWen3OmniTokenizer
+from ..models.neo_chat_moe.model import NeoChatTokenizer
# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
@@ -122,5 +123,7 @@ def get_tokenizer(
tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name)
elif model_type == "gemma3":
tokenizer = Gemma3Tokenizer(tokenizer, model_cfg)
+ elif model_type == "neo_chat":
+ tokenizer = NeoChatTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name)
return tokenizer
From f4fccf25bddd471660aa5cc6be5ac2110fb67b4a Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Fri, 3 Apr 2026 08:39:42 +0000
Subject: [PATCH 08/41] add lightx2v & openrouter api
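
The /v1/chat/completions route is rebound to the v2 handler, which interleaves
text decoding with x2i image generation: decoding stops on the tokenizer's
image_start_tag, the accumulated prompt is sent to the x2i server, and each
returned image is appended back as an image_tag plus a base64 multimodal input
before decoding resumes. Generated images are surfaced OpenRouter-style through
the new images field on ChatMessage/DeltaMessage as data URLs. A minimal sketch
of the encoding step, mirroring _raw_image_to_data_url (images from the backend
are assumed to be either raw bytes or a base64 string):

    import base64

    def to_data_url(image, image_type="jpeg"):
        # bytes are base64-encoded; strings are assumed to already be base64
        # (or a ready-made data: URL, which is passed through unchanged)
        mime = {"jpeg": "image/jpeg", "png": "image/png", "webp": "image/webp"}[image_type]
        if isinstance(image, bytes):
            return f"data:{mime};base64,{base64.b64encode(image).decode('ascii')}"
        if image.startswith("data:"):
            return image
        return f"data:{mime};base64,{image}"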
---
lightllm/models/neo_chat_moe/infer_struct.py | 2 -
lightllm/server/api_cli.py | 8 +-
lightllm/server/api_http.py | 22 +-
lightllm/server/api_models.py | 29 ++-
lightllm/server/api_openai.py | 206 +++++++++++++++---
lightllm/server/core/objs/sampling_params.py | 6 +
lightllm/server/core/objs/x2i_params.py | 33 ++-
lightllm/server/tokenizer.py | 2 +-
lightllm/server/x2i_server/manager.py | 91 ++++----
.../server/x2i_server/past_kv_cache_client.py | 26 ++-
lightllm/utils/config_utils.py | 2 +
11 files changed, 326 insertions(+), 101 deletions(-)
diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py
index add8abda08..1693bcb964 100644
--- a/lightllm/models/neo_chat_moe/infer_struct.py
+++ b/lightllm/models/neo_chat_moe/infer_struct.py
@@ -66,8 +66,6 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor:
images = p.get("images", [])
for img in images:
b_image_start_idx.append(img["start_idx"])
- # a = img["start_idx"]
- # print(f"img start_idx: {a}")
b_image_len.append(img["token_num"])
b_image_thwd.append(img["grid_thwd"])
b_image_nums.append(len(images))
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 4c4be78b3c..96697fae22 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -305,7 +305,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--enable_multimodal_x2i",
action="store_true",
- help="Whether or not to allow to generate images (requird --enable_multimodal)."
+ help="Whether or not to allow to generate images (requird --enable_multimodal).",
)
parser.add_argument(
"--x2i_server_used_gpus",
@@ -313,6 +313,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=1,
help="Number of GPUs to use for x2i server (requird --enable_multimodal_x2i).",
)
+ parser.add_argument(
+ "--x2v_gen_model_config",
+ type=str,
+ default=None,
+ help="Path of the x2v config file.",
+ )
parser.add_argument(
"--enable_mps", action="store_true", help="Whether to enable nvidia mps for multimodal service."
)
diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 65571bf0fc..4f7e3c18bd 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -89,6 +89,7 @@ def set_args(self, args: StartArgs):
if args.enable_multimodal_x2i:
from .api_lightllm import lightllm_generate_image
+
self.g_generate_image_func = lightllm_generate_image
setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::api_server")
@@ -242,15 +243,15 @@ async def compat_generate(request: Request) -> Response:
return await generate(request)
-@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
-async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response:
- if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
- return create_error_response(
- HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface"
- )
+# @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+# async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response:
+# if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
+# return create_error_response(
+# HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface"
+# )
- resp = await chat_completions_impl(request, raw_request)
- return resp
+# resp = await chat_completions_impl(request, raw_request)
+# return resp
@app.post("/v1/completions", response_model=CompletionResponse)
@@ -263,6 +264,7 @@ async def completions(request: CompletionRequest, raw_request: Request) -> Respo
resp = await completions_impl(request, raw_request)
return resp
+
@app.post("/generate_image")
async def generate_image(request: Request) -> Response:
if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
@@ -275,7 +277,8 @@ async def generate_image(request: Request) -> Response:
except Exception as e:
return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))
-@app.post("/v2/chat/completions", response_model=ChatCompletionResponse)
+
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def completions_v2(request: ChatCompletionRequestV2, raw_request: Request) -> Response:
if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
return create_error_response(
@@ -285,6 +288,7 @@ async def completions_v2(request: ChatCompletionRequestV2, raw_request: Request)
resp = await chat_completions_impl_v2(request, raw_request)
return resp
+
@app.get("/tokens")
@app.post("/tokens")
async def tokens(request: Request):
diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py
index 5725799338..737e1b5b3c 100644
--- a/lightllm/server/api_models.py
+++ b/lightllm/server/api_models.py
@@ -278,6 +278,8 @@ class ChatMessage(BaseModel):
content: Optional[Union[str, List[MessageContent]]] = None
reasoning_content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
+ # OpenRouter-style: generated images alongside text; content may include "" placeholders
+ images: Optional[List[MessageContent]] = None
class ChatCompletionResponseChoice(BaseModel):
@@ -304,6 +306,7 @@ class DeltaMessage(BaseModel):
content: Optional[Union[str, List[MessageContent]]] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
reasoning_content: Optional[str] = None
+ images: Optional[List[MessageContent]] = None
class ChatCompletionStreamResponseChoice(BaseModel):
@@ -374,20 +377,36 @@ def ensure_id_is_str(cls, v):
# Supported values
AspectRatio: TypeAlias = Literal[
- "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4",
- "9:16", "16:9", "21:9",
+ "1:1",
+ "2:3",
+ "3:2",
+ "3:4",
+ "4:3",
+ "4:5",
+ "5:4",
+ "9:16",
+ "16:9",
+ "21:9",
]
-ImageSize: TypeAlias = Literal["0.5K", "1K", "2K", "4K"]
+ImageSize: TypeAlias = Literal["0.5K", "1K", "2K", "4K"]
Modality: TypeAlias = Literal["text", "image", "audio"]
ImageType: TypeAlias = Literal["png", "jpeg", "webp"]
+
class ImageConfig(BaseModel):
aspect_ratio: AspectRatio = "1:1"
image_size: ImageSize = "1K"
image_type: ImageType = "jpeg"
+ # X2I / diffusion sampling (optional; server defaults apply when omitted)
+ steps: Optional[int] = None
+ guidance_scale: Optional[float] = None
+ image_guidance_scale: Optional[float] = None
+ seed: Optional[int] = None
+ num_images: Optional[int] = None
+ cfg_norm: Optional[Literal["none", "cfg_zero_star", "global", "text_channel", "channel"]] = None
# Mapping to actual resolutions (base resolution for 1K)
_aspect_ratio_to_resolution: ClassVar[dict] = {
@@ -427,7 +446,7 @@ def validate_image_size(cls, v):
@field_validator("image_type")
@classmethod
def validate_image_type(cls, v):
- if v not in ['jpeg', 'png', 'webp']:
+ if v not in ["jpeg", "png", "webp"]:
raise ValueError(f"unsupported image type: {v}")
return v
@@ -463,4 +482,4 @@ def validate_image_config(self):
else:
if self.image_config is not None:
raise ValueError("image_config provided but 'image' not in modalities")
- return self
\ No newline at end of file
+ return self
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index 07920fad19..9ea4b13bc4 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -163,6 +163,7 @@ def _process_tools_stream(index: int, delta: str, parser_dict: Dict, request: Ch
normal_text, calls = parser.parse_stream_chunk(delta)
return normal_text, calls
+
def _get_images_and_audios(request: ChatCompletionRequest):
images, audios = [], []
for message in request.messages:
@@ -199,6 +200,7 @@ def _get_images_and_audios(request: ChatCompletionRequest):
raise ValueError("Unrecognized audio input. Supports local path, http url, base64.")
return images, audios
+
def _get_tools(request: ChatCompletionRequest):
if request.tools and request.tool_choice != "none":
# request.skip_special_tokens = False
@@ -210,6 +212,7 @@ def _get_tools(request: ChatCompletionRequest):
]
else:
tools = [item.function.model_dump() for item in request.tools]
+ return tools
async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Request) -> Response:
@@ -559,7 +562,7 @@ async def _get_text_generator_input(request: ChatCompletionRequest):
"best_of": request.n,
"add_special_tokens": False,
}
- # Structured output handling
+ # Structured output handling
if request.response_format:
if request.response_format.type == "json_schema":
obj = request.response_format.json_schema
@@ -580,6 +583,43 @@ async def _get_text_generator_input(request: ChatCompletionRequest):
return prompt, sampling_params, multimodal_params
+def _raw_image_to_data_url(image: Union[str, bytes], image_type: str) -> str:
+ mime = {"jpeg": "image/jpeg", "png": "image/png", "webp": "image/webp"}[image_type]
+ if isinstance(image, bytes):
+ b64 = base64.b64encode(image).decode("ascii")
+ else:
+ if image.startswith("data:"):
+ return image
+ b64 = image
+ return f"data:{mime};base64,{b64}"
+
+
+def _message_contents_from_raw_images(images: List[Any], image_type: str) -> List[MessageContent]:
+ return [
+ MessageContent(
+ type="image_url",
+ image_url=ImageURL(url=_raw_image_to_data_url(img, image_type)),
+ )
+ for img in images
+ ]
+
+
+def _normalize_image_b64_for_multimodal(image: Union[str, bytes]) -> str:
+ if isinstance(image, bytes):
+ return base64.b64encode(image).decode("ascii")
+ return image
+
+
+def _apply_image_generation_stop(chat_request: ChatCompletionRequest, image_start_tag: str) -> None:
+ stop = chat_request.stop or []
+ if isinstance(stop, str):
+ stop = [stop]
+ stop = list(stop)
+ if image_start_tag not in stop:
+ stop.append(image_start_tag)
+ chat_request.stop = stop
+
+
async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request: Request) -> Response:
from .api_http import g_objs
@@ -598,48 +638,135 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
chat_request: ChatCompletionRequest = ChatCompletionRequest(**request.model_dump())
- # logger.info(f"{type(chat_request)} and {type(request)} and {request.model_dump_json()} and {chat_request.model_dump_json()}")
if "image" not in request.modalities:
return await chat_completions_impl(chat_request, raw_request)
- if not request.stream or request.n != 1:
- return create_error_response(HTTPStatus.BAD_REQUEST, "image only support stream api with n = 1")
+ if request.n != 1:
+ return create_error_response(
+ HTTPStatus.BAD_REQUEST,
+ "multimodal image generation only supports n = 1",
+ )
image_start_tag = g_objs.httpserver_manager.tokenizer.image_start_tag
image_tag = g_objs.httpserver_manager.tokenizer.image_tag
- stop = chat_request.stop or []
- if isinstance(stop, str): stop = [stop]
- chat_request.stop = stop.append(image_start_tag)
+ _apply_image_generation_stop(chat_request, image_start_tag)
created_time = int(time.time())
-
prompt, sampling_params, multimodal_params = await _get_text_generator_input(chat_request)
- width, height = request.image_config.get_resolution()
- x2i_params_dict = {
- "width": width,
- "height": height,
- }
-
x2i_params = X2IParams()
- x2i_params.init(**x2i_params_dict)
+ x2i_params.init_from_image_config(request.image_config)
+
+ if not request.stream:
+ from .req_id_generator import convert_sub_id_to_group_id
+
+ full_text = ""
+ response_images: List[MessageContent] = []
+ prompt_tokens = 0
+ completion_tokens = 0
+ finish_reason: Optional[str] = "stop"
+ group_request_id = None
+ max_image_gen_num = 15 # TODO: make this configurable
+
+ while max_image_gen_num > 0:
+ max_image_gen_num -= 1
+ need_call_x2i = False
+ output_chunk = ""
+ text_generator = g_objs.httpserver_manager.generate(
+ prompt, sampling_params, multimodal_params.clone(), request=raw_request
+ )
+ async for sub_req_id, request_output, metadata, finish_status in text_generator:
+ prompt_tokens = metadata["prompt_tokens"]
+ completion_tokens += 1
+ if group_request_id is None:
+ group_request_id = convert_sub_id_to_group_id(sub_req_id)
+ delta = request_output
+ if finish_status.is_finished():
+ finish_reason = finish_status.get_finish_reason()
+ if delta == image_start_tag:
+ need_call_x2i = True
+ continue
+ output_chunk += delta
+
+ full_text += output_chunk
+
+ if need_call_x2i:
+ prompt += output_chunk
+ images = await g_objs.httpserver_manager.generate_image(
+ prompt, x2i_params, multimodal_params.clone(), request=raw_request
+ )
+ if len(images) == 0:
+ logger.warning(f"No image generated by x2i: {prompt[-100:]}, exit...")
+ break
+ response_images.extend(_message_contents_from_raw_images(images, request.image_config.image_type))
+ for image in images:
+ prompt += image_tag
+ full_text += image_tag
+ multimodal_params.add_image({"type": "base64", "data": _normalize_image_b64_for_multimodal(image)})
+ else:
+ break
+
+ reasoning_text = None
+ text_out = full_text
+ reasoning_parser = get_env_start_args().reasoning_parser
+ if reasoning_parser and request.separate_reasoning:
+ request_enable_reasoning = _get_reasoning_from_request(request)
+ try:
+ parser = ReasoningParser(
+ model_type=reasoning_parser,
+ stream_reasoning=False,
+ force_reasoning=request_enable_reasoning,
+ )
+ reasoning_text, text_out = parser.parse_non_stream(full_text)
+ except Exception as e:
+ logger.error(f"Reasoning parsing error: {e}")
+ return create_error_response(
+ HTTPStatus.BAD_REQUEST,
+ "Failed to parse reasoning content!",
+ )
+
+ usage = UsageInfo(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+ chat_message = ChatMessage(
+ role="assistant",
+ content=text_out,
+ reasoning_content=reasoning_text if reasoning_text else "",
+ images=response_images if response_images else None,
+ )
+ choice = ChatCompletionResponseChoice(
+ index=0,
+ message=chat_message,
+ finish_reason=finish_reason,
+ )
+ return ChatCompletionResponse(
+ id=group_request_id or f"chatcmpl-{uuid.uuid4().hex}",
+ created=created_time,
+ model=request.model,
+ choices=[choice],
+ usage=usage,
+ )
async def stream_result() -> AsyncGenerator[bytes, None]:
nonlocal prompt
from .req_id_generator import convert_sub_id_to_group_id
+
prompt_tokens = 0
completion_tokens = 0
finish_reason = None
group_request_id = None
- max_image_gen_num = 15 # TODO: make this configurable
+ max_image_gen_num = 15 # TODO: make this configurable
while max_image_gen_num > 0:
max_image_gen_num -= 1
need_call_x2i = False
text_generator = g_objs.httpserver_manager.generate(
- prompt, sampling_params, multimodal_params.clone(), request=raw_request)
+ prompt, sampling_params, multimodal_params.clone(), request=raw_request
+ )
reasoning_parser_dict = {}
output_chunk = ""
@@ -659,7 +786,7 @@ async def stream_result() -> AsyncGenerator[bytes, None]:
output_chunk += delta
- # Handle reasoning content
+ # Handle reasoning content
if get_env_start_args().reasoning_parser and request.separate_reasoning:
reasoning_text, delta = _process_reasoning_stream(
index, delta, reasoning_parser_dict, request_output, request
@@ -678,7 +805,6 @@ async def stream_result() -> AsyncGenerator[bytes, None]:
)
yield f"data: {chunk.model_dump_json()}\n\n"
-
delta_message = DeltaMessage(role="assistant", content=delta)
if finish_status.is_finished():
finish_reason = finish_status.get_finish_reason()
@@ -691,33 +817,51 @@ async def stream_result() -> AsyncGenerator[bytes, None]:
model=request.model,
choices=[stream_choice],
)
- yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8")
-
+ yield ("data: " + json.dumps(stream_resp.model_dump(), ensure_ascii=False) + "\n\n").encode("utf-8")
if need_call_x2i:
prompt += output_chunk
images = await g_objs.httpserver_manager.generate_image(
- prompt, x2i_params, multimodal_params.clone(), request=raw_request)
+ prompt, x2i_params, multimodal_params.clone(), request=raw_request
+ )
if len(images) == 0:
logger.warning(f"No image generated by x2i: {prompt[-100:]}, exit...")
break
- delta_message = DeltaMessage(role="assistant", content=[])
- for image in images:
- message_content = MessageContent(type="image_url", image_url=ImageURL(url=image))
- delta_message.content.append(message_content)
- prompt += image_tag
- multimodal_params.add_image({"type": "base64", "data": image})
+ tag_chunk = ChatCompletionStreamResponse(
+ id=group_request_id,
+ created=created_time,
+ model=request.model,
+ choices=[
+ ChatCompletionStreamResponseChoice(
+ index=0,
+ delta=DeltaMessage(role="assistant", content=image_tag),
+ finish_reason=None,
+ )
+ ],
+ )
+ yield f"data: {tag_chunk.model_dump_json()}\n\n"
- stream_resp = ChatCompletionStreamResponse(
+ img_items = _message_contents_from_raw_images(images, request.image_config.image_type)
+ img_chunk = ChatCompletionStreamResponse(
id=group_request_id,
created=created_time,
model=request.model,
- choices=[ChatCompletionStreamResponseChoice(index=0, delta=delta_message)],
+ choices=[
+ ChatCompletionStreamResponseChoice(
+ index=0,
+ delta=DeltaMessage(role="assistant", images=img_items),
+ finish_reason=None,
+ )
+ ],
)
- yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8")
+ yield f"data: {img_chunk.model_dump_json()}\n\n"
+
+ for image in images:
+ prompt += image_tag
+ multimodal_params.add_image({"type": "base64", "data": _normalize_image_b64_for_multimodal(image)})
else:
break
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index 3cf5cb0887..85112a0379 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -293,6 +293,8 @@ class SamplingParams(ctypes.Structure):
("ignore_eos", ctypes.c_bool),
# the max number of image patches to be used in the internvl model, for the test
("image_max_patch_num", ctypes.c_int),
+ ("min_pixels", ctypes.c_int),
+ ("max_pixels", ctypes.c_int),
("max_new_tokens", ctypes.c_int),
("min_new_tokens", ctypes.c_int),
# Whether to count input tokens for presence_penalty, frequency_penalty and repetition_penalty
@@ -346,6 +348,8 @@ def init(self, tokenizer, **kwargs):
self.top_k = kwargs.get("top_k", SamplingParams._top_k)
self.ignore_eos = kwargs.get("ignore_eos", False)
self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
+ self.min_pixels = kwargs.get("min_pixels", -1)
+ self.max_pixels = kwargs.get("max_pixels", -1)
self.max_new_tokens = kwargs.get("max_new_tokens", 16384)
self.min_new_tokens = kwargs.get("min_new_tokens", 1)
self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
@@ -487,6 +491,8 @@ def to_dict(self):
"top_k": self.top_k,
"ignore_eos": self.ignore_eos,
"image_max_patch_num": self.image_max_patch_num,
+ "min_pixels": self.min_pixels,
+ "max_pixels": self.max_pixels,
"max_new_tokens": self.max_new_tokens,
"min_new_tokens": self.min_new_tokens,
"exponential_decay_length_penalty": self.exponential_decay_length_penalty.to_tuple(),
diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py
index 50b2595ec2..9a29ef949d 100644
--- a/lightllm/server/core/objs/x2i_params.py
+++ b/lightllm/server/core/objs/x2i_params.py
@@ -1,9 +1,10 @@
import ctypes
from dataclasses import dataclass
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
from enum import IntEnum
from .token_chunck_hash_list import PastKVCachePageList
+
class CfgNormType(IntEnum):
NONE = 0
CFG_ZERO_STAR = 1
@@ -52,10 +53,11 @@ class X2IParams(ctypes.Structure):
_num_images: int = 1
_cfg_norm: CfgNormType = CfgNormType.NONE
- def init(self, **kwargs):
+ def init(self, **kwargs):
def _get(key, default):
v = kwargs.get(key)
return v if v is not None else default
+
self.width = _get("width", X2IParams._width)
self.height = _get("height", X2IParams._height)
self.steps = _get("steps", X2IParams._steps)
@@ -70,6 +72,31 @@ def _get(key, default):
self.total_prompt_tokens = 0
self.request_id = 0
+ def init_from_image_config(self, image_config: Any) -> None:
+ """从 HTTP `image_config`(api_models.ImageConfig)填充,与 `init(**kwargs)` 共用默认值逻辑。"""
+ from lightllm.server.api_models import ImageConfig
+
+ if not isinstance(image_config, ImageConfig):
+ raise TypeError(f"expected ImageConfig, got {type(image_config)!r}")
+ w, h = image_config.get_resolution()
+ kwargs: Dict[str, Any] = {"width": w, "height": h}
+ if image_config.steps is not None:
+ kwargs["steps"] = image_config.steps
+ if image_config.guidance_scale is not None:
+ kwargs["guidance_scale"] = image_config.guidance_scale
+ if image_config.image_guidance_scale is not None:
+ kwargs["image_guidance_scale"] = image_config.image_guidance_scale
+ if image_config.seed is not None:
+ kwargs["seed"] = image_config.seed
+ if image_config.num_images is not None:
+ kwargs["num_images"] = image_config.num_images
+ if image_config.cfg_norm is not None:
+ for e in CfgNormType:
+ if e.as_str() == image_config.cfg_norm:
+ kwargs["cfg_norm"] = e
+ break
+ self.init(**kwargs)
+
def update(self, past_kv: PastKVCachePageList, meta: Dict):
item: PastKVCacheItem = meta.get("kv_cache_item")
past_kv.token_len = item.token_len
@@ -119,4 +146,4 @@ class PastKVCacheItem:
token_len: int
img_tokens: int
img_len: int
- page_indexes: List[int]
\ No newline at end of file
+ page_indexes: List[int]
diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py
index 87904aa97d..18f7778bbb 100644
--- a/lightllm/server/tokenizer.py
+++ b/lightllm/server/tokenizer.py
@@ -125,5 +125,5 @@ def get_tokenizer(
tokenizer = Gemma3Tokenizer(tokenizer, model_cfg)
elif model_type == "neo_chat":
tokenizer = NeoChatTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name)
-
+ print(type(tokenizer), tokenizer, flush=True)
return tokenizer
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index 828a5ec35e..5fe817f175 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -22,14 +22,15 @@
logger = init_logger(__name__)
-'''
+"""
manage a generation service,
1. start x2v pipelines
2. receive generation request from http_server.
3. call llm gen to obtain past key values
4. call x2v to generate images and pass the key values to it
5. return the generated images.
-'''
+"""
+
class X2IManager:
def __init__(
@@ -50,30 +51,43 @@ def __init__(
self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True)
async def wait_to_model_ready(self):
- # from lightx2v import LightX2VPipeline
- # self.gen_pipe = LightX2VPipeline(
- # model_path = self.args.model_dir,
- # model_cls = self.args.model_name,
- # task="t2i"
- # )
- # self.gen_pipe.create_generator(
- # config_json = self.args.x2v_gen_model_config,
- # )
-
- from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I
-
- self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device())
-
+ from lightx2v import LightX2VPipeline
+
+ self.gen_pipe = LightX2VPipeline(
+ model_path=self.args.model_dir,
+ model_cls="neopp",
+ support_tasks=["t2i", "i2i"],
+ )
+ self.gen_pipe.create_generator(
+ config_json=self.args.x2v_gen_model_config,
+ )
+ self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False})
+ # from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I
+ # self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device())
pass
async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams):
- images = self.naive_x2i.t2i(past_kv_cache, past_kv_cache_text, param)
- return images
+ print(past_kv_cache.shape, past_kv_cache_text.shape, param, flush=True)
+ self.gen_pipe.runner.set_kvcache_t2i(past_kv_cache, past_kv_cache_text)
+ image = self.gen_pipe.generate(
+ seed=param.seed,
+ task="t2i",
+ save_result_path="", # 返回base64,不需要指定路径了
+ target_shape=[param.height, param.width], # Height, Width
+ )
+ # images = self.naive_x2i.t2i(past_kv_cache, past_kv_cache_text, param)
+ return [image]
async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams):
- images = self.naive_x2i.it2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img, param)
- return images
-
+ self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img)
+ image = self.gen_pipe.generate(
+ seed=param.seed,
+ task="i2i",
+ save_result_path="", # 返回base64,不需要指定路径了
+ target_shape=[param.height, param.width], # Height, Width
+ )
+ # images = self.naive_x2i.it2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img, param)
+ return [image]
async def loop_for_fwd(self):
while True:
@@ -94,15 +108,15 @@ async def loop_for_fwd(self):
is_t2i = x2i_param.past_kvcache_img.is_empty()
past_kv_cache_img = None
- if not is_t2i: # t2i
+ if not is_t2i:  # it2i
past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i(
x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len
)
# release
self.send_to_httpserver.send_pyobj(
- X2ICacheRelease(request_id=x2i_param.request_id),
- protocol=pickle.HIGHEST_PROTOCOL)
+ X2ICacheRelease(request_id=x2i_param.request_id), protocol=pickle.HIGHEST_PROTOCOL
+ )
images = []
logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with: {x2i_param}")
@@ -111,20 +125,17 @@ async def loop_for_fwd(self):
else:
images = await self.it2i_generate(past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param)
- self.send_to_httpserver.send_pyobj(X2IResponse(
- request_id=x2i_param.request_id,
- images=images),
- protocol=pickle.HIGHEST_PROTOCOL)
+ self.send_to_httpserver.send_pyobj(
+ X2IResponse(request_id=x2i_param.request_id, images=images), protocol=pickle.HIGHEST_PROTOCOL
+ )
except Exception as e:
- self.send_to_httpserver.send_pyobj(X2IResponse(
- request_id=x2i_param.request_id,
- images=None),
- protocol=pickle.HIGHEST_PROTOCOL)
+ self.send_to_httpserver.send_pyobj(
+ X2IResponse(request_id=x2i_param.request_id, images=None), protocol=pickle.HIGHEST_PROTOCOL
+ )
logger.error(e)
-
async def loop_for_netio_req(self):
while True:
try:
@@ -139,6 +150,7 @@ async def loop_for_netio_req(self):
def clean_up(self):
pass
+
def setup_devices(args: StartArgs):
devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
logger.info(f"current devices: {devices} {torch.cuda.device_count()}")
@@ -152,11 +164,12 @@ def setup_devices(args: StartArgs):
if len(devices) < llm_need_gpus + x2i_need_gpus:
raise ValueError(f"devices {devices} not enough, need {llm_need_gpus} and {x2i_need_gpus}")
- os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices[
- llm_need_gpus:llm_need_gpus + x2i_need_gpus]))
+ os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, devices[llm_need_gpus : llm_need_gpus + x2i_need_gpus]))
- logger.info(f"setup devices for x2i server: {os.environ['CUDA_VISIBLE_DEVICES']}, "
- f"{torch.cuda.device_count()} {torch.cuda.current_device()}")
+ logger.info(
+ f"setup devices for x2i server: {os.environ['CUDA_VISIBLE_DEVICES']}, "
+ f"{torch.cuda.device_count()} {torch.cuda.current_device()}"
+ )
def start_x2i_process(args, pipe_writer):
@@ -166,7 +179,9 @@ def start_x2i_process(args, pipe_writer):
start_parent_check_thread()
set_current_device_id(torch.cuda.current_device())
try:
- x2iserver = X2IManager(args=args,)
+ x2iserver = X2IManager(
+ args=args,
+ )
asyncio.run(x2iserver.wait_to_model_ready())
except Exception as e:
logger.exception(str(e))
diff --git a/lightllm/server/x2i_server/past_kv_cache_client.py b/lightllm/server/x2i_server/past_kv_cache_client.py
index b06db8da3c..56781f49ab 100644
--- a/lightllm/server/x2i_server/past_kv_cache_client.py
+++ b/lightllm/server/x2i_server/past_kv_cache_client.py
@@ -58,7 +58,8 @@ def allocate_pages(self, req_id: int, need_tokens: int, img_tokens: int = 0, img
page_indexes, self.free_pages = self.free_pages[:need_pages], self.free_pages[need_pages:]
self.allocated_pages_dict[req_id] = PastKVCacheItem(
- req_id=req_id, token_len=need_tokens, page_indexes=page_indexes, img_tokens=img_tokens, img_len=img_len)
+ req_id=req_id, token_len=need_tokens, page_indexes=page_indexes, img_tokens=img_tokens, img_len=img_len
+ )
return page_indexes
@@ -77,17 +78,20 @@ def get_pages_by_req_id(self, req_id: int) -> Optional[List[int]]:
def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optional[torch.Tensor]:
if page_indexes is None:
return None
- assert token_num <= len(page_indexes) * self.token_page_size and \
- token_num > (len(page_indexes) - 1) * self.token_page_size
+ assert (
+ token_num <= len(page_indexes) * self.token_page_size
+ and token_num > (len(page_indexes) - 1) * self.token_page_size
+ )
(P, L, S, H, D) = self.cpu_kv_cache_tensor[page_indexes].shape
- # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (2, L, H // 2, P, S, D)
- # -> (2, L, H // 2, P * S, D) -> ( L, 2, H // 2, P * S, D)
- kv = self.cpu_kv_cache_tensor[page_indexes] \
- .view(P, L, S, 2, H // 2, D) \
- .permute(3, 1, 4, 0, 2, 5).contiguous() \
- .view(2, L, H // 2, P * S, D) \
- .permute(1, 0, 2, 3, 4)
- return kv[:, :, :, :token_num, :].contiguous()
+ # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (L, 2, H // 2, P, S, D) -> (L, 2, H // 2, P * S, D)
+ kv = (
+ self.cpu_kv_cache_tensor[page_indexes]
+ .view(P, L, S, 2, H // 2, D)
+ .permute(1, 3, 0, 2, 4, 5)
+ .contiguous()
+ .view(L, 2, P * S, H // 2, D)
+ )
+ return kv
def _create_shm_cpu_kv_cache(self):
shm_ptr = create_shm_kv_cache_ptr(
diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py
index 54f20384b7..a9088d9a2e 100644
--- a/lightllm/utils/config_utils.py
+++ b/lightllm/utils/config_utils.py
@@ -186,6 +186,8 @@ def has_vision_module(model_path: str) -> bool:
):
# Qwen3OmniMoeVisionTransformerPretrainedModel
return True
+ elif model_type == "neo_chat":
+ return True
else:
raise Exception("unknown vision model type")
except:
From 06571d0e78afb3b3fc3915587f5275dc625903ee Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Tue, 7 Apr 2026 06:21:13 +0000
Subject: [PATCH 09/41] fix x2v && add openai image-only
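
With modalities=["image"] the v2 endpoint now skips text decoding and calls the
x2i path once, returning the generated images on the response message; the KV
caches handed to lightx2v are truncated to their compressed lengths before
set_kvcache_t2i / set_kvcache_i2i. An illustrative client call (host, port and
model name are assumptions, not part of this patch):

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "neo_chat",
            "messages": [{"role": "user", "content": "a watercolor fox in the snow"}],
            "modalities": ["image"],
            "image_config": {"aspect_ratio": "1:1", "image_size": "1K", "image_type": "jpeg"},
        },
    )
    # images come back as data URLs on the assistant message
    print(resp.json()["choices"][0]["message"]["images"])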
---
lightllm/models/neo_chat_moe/model.py | 6 +++
lightllm/server/api_models.py | 4 +-
lightllm/server/api_openai.py | 48 +++++++++++++++++--
lightllm/server/x2i_server/manager.py | 21 +++++++-
.../server/x2i_server/past_kv_cache_client.py | 2 +-
5 files changed, 74 insertions(+), 7 deletions(-)
diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py
index ec8772cc72..900d2eb028 100644
--- a/lightllm/models/neo_chat_moe/model.py
+++ b/lightllm/models/neo_chat_moe/model.py
@@ -48,6 +48,7 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
def load_conversion_module(self, model_dir: str):
import importlib
+
conversion_path = os.path.join(model_dir, "conversation.py")
if not os.path.exists(conversion_path):
return None
@@ -151,6 +152,9 @@ def get_query_for_it2i(self, prompt: str):
query_condition = prompt + IMG_START_TOKEN if not prompt.endswith(IMG_START_TOKEN) else prompt
query_text_uncondition = self._build_t2i_query(IMG_TOKEN * image_len)
question_img_uncondition = self._build_t2i_query("")
+ print(f"query_condition: {query_condition}")
+ print(f"query_text_uncondition: {query_text_uncondition}")
+ print(f"question_img_uncondition: {question_img_uncondition}")
return query_condition, query_text_uncondition, question_img_uncondition
def get_query_for_t2i(self, prompt: str):
@@ -160,6 +164,8 @@ def get_query_for_t2i(self, prompt: str):
# f"Please generate an image based on the following description: {prompt}",
# thinking_content="\n\n\n\n")
query_uncondition = self._build_t2i_query("")
+ print(f"query_condition: {query_condition}", flush=True)
+ print(f"query_uncondition: {query_uncondition}", flush=True)
return query_condition, query_uncondition
diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py
index 737e1b5b3c..643bba66e3 100644
--- a/lightllm/server/api_models.py
+++ b/lightllm/server/api_models.py
@@ -468,8 +468,8 @@ class ChatCompletionRequestV2(ChatCompletionRequest):
@field_validator("modalities")
@classmethod
def validate_modalities(cls, v):
- if "text" not in v:
- raise ValueError("modalities must include 'text'")
+ if "text" not in v and v != ["image"]:
+ raise ValueError("modalities must include 'text', or be ['image'] for image-only generation")
if len(v) != len(set(v)):
raise ValueError("modalities must be unique")
return v
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index 9ea4b13bc4..0f364cdc1e 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -202,6 +202,7 @@ def _get_images_and_audios(request: ChatCompletionRequest):
def _get_tools(request: ChatCompletionRequest):
+ tools = None
if request.tools and request.tool_choice != "none":
# request.skip_special_tokens = False
if not isinstance(request.tool_choice, str):
@@ -537,15 +538,17 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
-async def _get_text_generator_input(request: ChatCompletionRequest):
+async def _get_text_generator_input(request: ChatCompletionRequest, apply_chat_template: bool = True):
from .api_http import g_objs
images, audios = _get_images_and_audios(request)
multimodal_params_dict = {"images": images, "audios": audios}
tools = _get_tools(request)
-
- prompt = await build_prompt(request, tools)
+ if apply_chat_template:
+ prompt = await build_prompt(request, tools)
+ else:
+ prompt = request.messages[-1].content
sampling_params_dict = {
"do_sample": request.do_sample,
@@ -620,6 +623,42 @@ def _apply_image_generation_stop(chat_request: ChatCompletionRequest, image_star
chat_request.stop = stop
+async def _chat_completion_image_only(
+ request: ChatCompletionRequestV2,
+ raw_request: Request,
+) -> ChatCompletionResponse:
+ from .api_http import g_objs
+
+ created_time = int(time.time())
+ x2i_params = X2IParams()
+ x2i_params.init_from_image_config(request.image_config)
+
+ prompt, _, multimodal_params = await _get_text_generator_input(request, apply_chat_template=False)
+
+ images = await g_objs.httpserver_manager.generate_image(
+ prompt, x2i_params, multimodal_params.clone(), request=raw_request
+ )
+ response_images = _message_contents_from_raw_images(images, request.image_config.image_type)
+ chat_message = ChatMessage(
+ role="assistant",
+ content=prompt,
+ images=response_images if response_images else None,
+ )
+ choice = ChatCompletionResponseChoice(
+ index=0,
+ message=chat_message,
+ finish_reason="stop",
+ )
+ usage = UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+ return ChatCompletionResponse(
+ id=f"chatcmpl-{uuid.uuid4().hex}",
+ created=created_time,
+ model=request.model,
+ choices=[choice],
+ usage=usage,
+ )
+
+
async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request: Request) -> Response:
from .api_http import g_objs
@@ -641,6 +680,9 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
if "image" not in request.modalities:
return await chat_completions_impl(chat_request, raw_request)
+ if request.modalities == ["image"]:
+ return await _chat_completion_image_only(request, raw_request)
+
if request.n != 1:
return create_error_response(
HTTPStatus.BAD_REQUEST,
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index 5fe817f175..e9e5e6de4b 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -67,7 +67,12 @@ async def wait_to_model_ready(self):
pass
async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams):
- print(past_kv_cache.shape, past_kv_cache_text.shape, param, flush=True)
+ past_kv_cache = self._truncate_kv_cache_to_compressed_len(
+ past_kv_cache, param.past_kvcache.get_compressed_len()
+ )
+ past_kv_cache_text = self._truncate_kv_cache_to_compressed_len(
+ past_kv_cache_text, param.past_kvcache_text.get_compressed_len()
+ )
self.gen_pipe.runner.set_kvcache_t2i(past_kv_cache, past_kv_cache_text)
image = self.gen_pipe.generate(
seed=param.seed,
@@ -79,6 +84,15 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams
return [image]
async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams):
+ past_kv_cache = self._truncate_kv_cache_to_compressed_len(
+ past_kv_cache, param.past_kvcache.get_compressed_len()
+ )
+ past_kv_cache_text = self._truncate_kv_cache_to_compressed_len(
+ past_kv_cache_text, param.past_kvcache_text.get_compressed_len()
+ )
+ past_kv_cache_img = self._truncate_kv_cache_to_compressed_len(
+ past_kv_cache_img, param.past_kvcache_img.get_compressed_len()
+ )
self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img)
image = self.gen_pipe.generate(
seed=param.seed,
@@ -150,6 +164,11 @@ async def loop_for_netio_req(self):
def clean_up(self):
pass
+ def _truncate_kv_cache_to_compressed_len(self, kv: torch.Tensor, compressed_len: int) -> torch.Tensor:
+ seq = kv.shape[2]
+ n = min(compressed_len, seq)
+ return kv[:, :, :n, :].contiguous()
+
def setup_devices(args: StartArgs):
devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
diff --git a/lightllm/server/x2i_server/past_kv_cache_client.py b/lightllm/server/x2i_server/past_kv_cache_client.py
index 56781f49ab..9e9f0d99d4 100644
--- a/lightllm/server/x2i_server/past_kv_cache_client.py
+++ b/lightllm/server/x2i_server/past_kv_cache_client.py
@@ -83,7 +83,7 @@ def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optio
and token_num > (len(page_indexes) - 1) * self.token_page_size
)
(P, L, S, H, D) = self.cpu_kv_cache_tensor[page_indexes].shape
- # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (L, 2, H // 2, P, S, D) -> (L, 2, H // 2, P * S, D)
+ # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (L, 2, P, S, H // 2, D) -> (L, 2, P * S, H // 2, D)
kv = (
self.cpu_kv_cache_tensor[page_indexes]
.view(P, L, S, 2, H // 2, D)
From 0619f9dcab86fb3daf65f67887b4bcf61249ae9d Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Wed, 8 Apr 2026 07:09:02 +0000
Subject: [PATCH 10/41] add naive option.
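
Adds --x2i_use_naive_impl to fall back to the in-tree PyTorch implementation
instead of lightx2v, and restructures the dual classifier-free guidance in the
naive it2i path so that unconditional passes are only computed when their
scales require them. The combination logic, restated as a standalone sketch
(tensor plumbing elided; names follow the patch):

    def combine_cfg(v_cond, out_img_cond, out_uncond, cfg_scale, img_cfg_scale, use_cfg):
        # dual CFG as restructured in it2i_generate:
        #   v_cond       - text + image conditioned prediction
        #   out_img_cond - image conditioned, text unconditioned
        #   out_uncond   - fully unconditioned
        if not use_cfg or (cfg_scale == 1 and img_cfg_scale == 1):
            return v_cond
        if img_cfg_scale == 1:  # text guidance only
            return out_img_cond + cfg_scale * (v_cond - out_img_cond)
        if cfg_scale == img_cfg_scale:  # one unconditional pass suffices
            return out_uncond + cfg_scale * (v_cond - out_uncond)
        return (
            out_uncond
            + cfg_scale * (v_cond - out_img_cond)
            + img_cfg_scale * (out_img_cond - out_uncond)
        )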
---
lightllm/server/api_cli.py | 5 +
lightllm/server/core/objs/start_args_type.py | 1 +
lightllm/server/core/objs/x2i_params.py | 7 +-
lightllm/server/x2i_server/manager.py | 25 +++--
.../x2i_server/naive/modeling_neo_chat.py | 101 +++++++++++++-----
.../server/x2i_server/past_kv_cache_client.py | 27 +++--
6 files changed, 121 insertions(+), 45 deletions(-)
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 96697fae22..5cf656d440 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -313,6 +313,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=1,
help="Number of GPUs to use for x2i server (requird --enable_multimodal_x2i).",
)
+ parser.add_argument(
+ "--x2i_use_naive_impl",
+ action="store_true",
+ help="Whether to use the native backend for x2i generation. If set, it will use the naive pytorch backend mainly for testing and debugging purpose.",
+ )
parser.add_argument(
"--x2v_gen_model_config",
type=str,
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 5cf13ef19a..7429b22593 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -163,6 +163,7 @@ class StartArgs:
x2i_port: int = field(default=None)
http_server_port_for_x2i: int = field(default=None)
x2i_server_used_gpus: int = field(default=1)
+ x2i_use_naive_impl: bool = field(default=False)
# multi_modal
enable_multimodal_x2i: bool = field(default=False)
diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py
index 9a29ef949d..3e4498ab4b 100644
--- a/lightllm/server/core/objs/x2i_params.py
+++ b/lightllm/server/core/objs/x2i_params.py
@@ -37,6 +37,7 @@ class X2IParams(ctypes.Structure):
("seed", ctypes.c_int),
("num_images", ctypes.c_int),
("cfg_norm", ctypes.c_int),
+ ("timestep_shift", ctypes.c_float),
("past_kvcache", PastKVCachePageList),
("past_kvcache_text", PastKVCachePageList),
("past_kvcache_img", PastKVCachePageList),
@@ -47,11 +48,12 @@ class X2IParams(ctypes.Structure):
_width: int = 1024
_height: int = 1024
_steps: int = 50
- _guidance_scale: float = 7.0
- _image_guidance_scale: float = 7.0
+ _guidance_scale: float = 4.0
+ _image_guidance_scale: float = 1.0
_seed: int = 42
_num_images: int = 1
_cfg_norm: CfgNormType = CfgNormType.NONE
+ _timestep_shift: float = 3.0
def init(self, **kwargs):
def _get(key, default):
@@ -66,6 +68,7 @@ def _get(key, default):
self.seed = _get("seed", X2IParams._seed)
self.num_images = _get("num_images", X2IParams._num_images)
self.cfg_norm = _get("cfg_norm", X2IParams._cfg_norm)
+ self.timestep_shift = _get("timestep_shift", X2IParams._timestep_shift)
self.past_kvcache = PastKVCachePageList()
self.past_kvcache_text = PastKVCachePageList()
self.past_kvcache_img = PastKVCachePageList()
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index e9e5e6de4b..7dac4f774a 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -50,7 +50,14 @@ def __init__(
self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True)
+ self.use_naive_x2i = args.x2i_use_naive_impl
+
async def wait_to_model_ready(self):
+ if self.use_naive_x2i:
+ from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I
+ self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device())
+ return
+
from lightx2v import LightX2VPipeline
self.gen_pipe = LightX2VPipeline(
@@ -62,11 +69,12 @@ async def wait_to_model_ready(self):
config_json=self.args.x2v_gen_model_config,
)
self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False})
- # from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I
- # self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device())
- pass
async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams):
+ if self.use_naive_x2i:
+ images = self.naive_x2i.t2i(past_kv_cache, past_kv_cache_text, param)
+ return images
+
past_kv_cache = self._truncate_kv_cache_to_compressed_len(
past_kv_cache, param.past_kvcache.get_compressed_len()
)
@@ -84,6 +92,10 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams
return [image]
async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams):
+ if self.use_naive_x2i:
+ images = self.naive_x2i.it2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img, param)
+ return images
+
past_kv_cache = self._truncate_kv_cache_to_compressed_len(
past_kv_cache, param.past_kvcache.get_compressed_len()
)
@@ -100,7 +112,6 @@ async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_i
save_result_path="", # 返回base64,不需要指定路径了
target_shape=[param.height, param.width], # Height, Width
)
- # images = self.naive_x2i.it2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img, param)
return [image]
async def loop_for_fwd(self):
@@ -113,18 +124,18 @@ async def loop_for_fwd(self):
x2i_param = self.waiting_reqs.pop(0)
past_kv_cache = self.past_kv_cache_client.get_kv_cache_for_x2i(
- x2i_param.past_kvcache.get_all(), x2i_param.past_kvcache.token_len
+ x2i_param.past_kvcache.get_all(), x2i_param.past_kvcache.token_len, self.use_naive_x2i
)
past_kv_cache_text = self.past_kv_cache_client.get_kv_cache_for_x2i(
- x2i_param.past_kvcache_text.get_all(), x2i_param.past_kvcache_text.token_len
+ x2i_param.past_kvcache_text.get_all(), x2i_param.past_kvcache_text.token_len, self.use_naive_x2i
)
is_t2i = x2i_param.past_kvcache_img.is_empty()
past_kv_cache_img = None
if not is_t2i:  # it2i
past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i(
- x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len
+ x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len, self.use_naive_x2i
)
# release
diff --git a/lightllm/server/x2i_server/naive/modeling_neo_chat.py b/lightllm/server/x2i_server/naive/modeling_neo_chat.py
index 890d761391..326e2a574d 100644
--- a/lightllm/server/x2i_server/naive/modeling_neo_chat.py
+++ b/lightllm/server/x2i_server/naive/modeling_neo_chat.py
@@ -372,7 +372,7 @@ def it2i_generate(self,
img_cfg_scale=1,
cfg_norm='none',
enable_timestep_shift=True,
- timestep_shift=1,
+ timestep_shift=3,
image_size=(256, 256),
num_steps=30,
cfg_interval=(0.1, 1.0),
@@ -421,7 +421,7 @@ def it2i_generate(self,
# init noise image tokens
grid_h = image_size[1] // self.patch_size
grid_w = image_size[0] // self.patch_size
- grid_hw = torch.tensor([[grid_h, grid_w]]*batch_size, device=device)
+ grid_hw = torch.tensor([[grid_h, grid_w]] * batch_size, device=device)
noise_scale = self.noise_scale
if self.noise_scale_mode in ("resolution", "dynamic", 'dynamic_sqrt'):
@@ -446,6 +446,7 @@ def it2i_generate(self,
for step_i in range(num_steps):
t = timesteps[step_i]
t_next = timesteps[step_i + 1]
+ use_cfg = t >= cfg_interval[0] and t <= cfg_interval[1]
z = self.patchify(image_prediction, self.patch_size * merge_size)
image_input = self.patchify(image_prediction, self.patch_size, channel_first=True)
@@ -459,27 +460,69 @@ def it2i_generate(self,
image_embeds = image_embeds + timestep_embeddings
v_pred_condition = self._t2i_predict_v(image_embeds, indexes_image_condition, attention_mask_condition, past_key_values_condition, t, z, image_token_num=token_h*token_w, timestep_embeddings=timestep_embeddings,image_size=image_size)
- if t > cfg_interval[0] and t < cfg_interval[1]:
- if cfg_scale > 1:
- v_pred_text_uncondition = self._t2i_predict_v(image_embeds, indexes_image_text_uncondition, attention_mask_text_uncondition, past_key_values_text_uncondition, t, z, image_token_num=token_h*token_w, timestep_embeddings=timestep_embeddings,image_size=image_size)
- else:
- v_pred_text_uncondition = 0
- if img_cfg_scale > 1:
- v_pred_img_uncondition = self._t2i_predict_v(image_embeds, indexes_image_img_uncondition, attention_mask_img_uncondition, past_key_values_img_uncondition, t, z, image_token_num=token_h*token_w, timestep_embeddings=timestep_embeddings,image_size=image_size)
- else:
- v_pred_img_uncondition = 0
+ if not use_cfg:
+ v_pred = v_pred_condition
+ elif cfg_scale == 1 and img_cfg_scale == 1:
+ v_pred = v_pred_condition
+ elif img_cfg_scale == 1:
+ out_img_cond = self._t2i_predict_v(
+ image_embeds,
+ indexes_image_text_uncondition,
+ attention_mask_text_uncondition,
+ past_key_values_text_uncondition,
+ t,
+ z,
+ image_token_num=token_h * token_w,
+ timestep_embeddings=timestep_embeddings,
+ image_size=image_size,
+ )
+ v_pred = out_img_cond + cfg_scale * (v_pred_condition - out_img_cond)
+ elif cfg_scale == img_cfg_scale:
+ out_uncond = self._t2i_predict_v(
+ image_embeds,
+ indexes_image_img_uncondition,
+ attention_mask_img_uncondition,
+ past_key_values_img_uncondition,
+ t,
+ z,
+ image_token_num=token_h * token_w,
+ timestep_embeddings=timestep_embeddings,
+ image_size=image_size,
+ )
+ v_pred = out_uncond + cfg_scale * (v_pred_condition - out_uncond)
+ else:
+ out_img_cond = self._t2i_predict_v(
+ image_embeds,
+ indexes_image_text_uncondition,
+ attention_mask_text_uncondition,
+ past_key_values_text_uncondition,
+ t,
+ z,
+ image_token_num=token_h * token_w,
+ timestep_embeddings=timestep_embeddings,
+ image_size=image_size,
+ )
+ out_uncond = self._t2i_predict_v(
+ image_embeds,
+ indexes_image_img_uncondition,
+ attention_mask_img_uncondition,
+ past_key_values_img_uncondition,
+ t,
+ z,
+ image_token_num=token_h * token_w,
+ timestep_embeddings=timestep_embeddings,
+ image_size=image_size,
+ )
+ v_pred = (
+ out_uncond
+ + cfg_scale * (v_pred_condition - out_img_cond)
+ + img_cfg_scale * (out_img_cond - out_uncond)
+ )
- if t > cfg_interval[0] and t < cfg_interval[1]:
- v_pred_text = v_pred_text_uncondition + cfg_scale * (v_pred_condition - v_pred_text_uncondition)
- if cfg_norm == 'text_channel':
- norm_v_condition = torch.norm(v_pred_condition, dim=-1, keepdim=True)
- norm_v_cfg = torch.norm(v_pred_text, dim=-1, keepdim=True)
- scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
- v_pred_text = v_pred_text * scale
- v_pred = v_pred_img_uncondition + img_cfg_scale * (v_pred_text - v_pred_img_uncondition)
+ if cfg_scale > 1 or img_cfg_scale > 1:
if cfg_norm == 'global':
- norm_v_condition = torch.norm(v_pred_condition, dim=(1,2), keepdim=True)
- norm_v_cfg = torch.norm(v_pred, dim=(1,2), keepdim=True)
+ norm_v_condition = torch.norm(v_pred_condition, dim=(1, 2), keepdim=True)
+ norm_v_cfg = torch.norm(v_pred, dim=(1, 2), keepdim=True)
scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
v_pred = v_pred * scale
elif cfg_norm == 'channel':
@@ -488,8 +531,6 @@ def it2i_generate(self,
scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
v_pred = v_pred * scale
- else:
- v_pred = v_pred_condition
z = z + (t_next - t) * v_pred
@@ -517,7 +558,7 @@ def t2i_generate(self,
past_key_values_uncondition,
text_lens,
cfg_scale=1,
- timestep_shift=1,
+ timestep_shift=3,
enable_timestep_shift=True,
cfg_norm='none',
image_size=(256, 256),
@@ -601,7 +642,7 @@ def t2i_generate(self,
timestep_embeddings=timestep_embeddings, image_size=image_size)
- if t > cfg_interval[0] and t < cfg_interval[1] and cfg_scale > 1:
+ if t >= cfg_interval[0] and t <= cfg_interval[1] and cfg_scale > 1:
v_pred_uncondition = self._t2i_predict_v(image_embeds, indexes_image_uncondition, attention_mask_uncondition, past_key_values_uncondition, t, z, image_token_num=token_h*token_w,
timestep_embeddings=timestep_embeddings, image_size=image_size)
if cfg_norm == 'cfg_zero_star':
@@ -623,6 +664,12 @@ def t2i_generate(self,
norm_v_cfg = torch.norm(v_pred, dim=(1,2), keepdim=True)
scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
v_pred = v_pred * scale
+ elif cfg_norm == 'channel':
+ norm_v_condition = torch.norm(v_pred_condition, dim=-1, keepdim=True)
+ norm_v_cfg = torch.norm(v_pred, dim=-1, keepdim=True)
+ scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
+ v_pred = v_pred * scale
+
else:
v_pred = v_pred_condition
@@ -718,7 +765,8 @@ def t2i(self, past_kv, past_kv_txt, param: X2IParams):
cfg_scale=param.guidance_scale,
image_size=(param.width, param.height),
num_steps=param.steps,
- batch_size=param.num_images)
+ batch_size=param.num_images,
+ timestep_shift=param.timestep_shift)
return self._post_process(output)
@@ -750,6 +798,7 @@ def it2i(self, past_kv, past_kv_txt, past_kv_img, param: X2IParams):
image_size=(param.width, param.height),
num_steps=param.steps,
batch_size=param.num_images,
+ timestep_shift=param.timestep_shift,
)
return self._post_process(output)
diff --git a/lightllm/server/x2i_server/past_kv_cache_client.py b/lightllm/server/x2i_server/past_kv_cache_client.py
index 9e9f0d99d4..707a3643c7 100644
--- a/lightllm/server/x2i_server/past_kv_cache_client.py
+++ b/lightllm/server/x2i_server/past_kv_cache_client.py
@@ -75,7 +75,7 @@ def get_pages_by_req_id(self, req_id: int) -> Optional[List[int]]:
item = self.allocated_pages_dict.get(req_id, None)
return item
- def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optional[torch.Tensor]:
+ def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int, use_naive=False) -> Optional[torch.Tensor]:
if page_indexes is None:
return None
assert (
@@ -83,15 +83,22 @@ def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optio
and token_num > (len(page_indexes) - 1) * self.token_page_size
)
(P, L, S, H, D) = self.cpu_kv_cache_tensor[page_indexes].shape
- # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (L, 2, P, S, H // 2, D) -> (L, 2, P * S, H // 2, D)
- kv = (
- self.cpu_kv_cache_tensor[page_indexes]
- .view(P, L, S, 2, H // 2, D)
- .permute(1, 3, 0, 2, 4, 5)
- .contiguous()
- .view(L, 2, P * S, H // 2, D)
- )
- return kv
+ if not use_naive:
+ # (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (L, 2, P, S, H // 2, D) -> (L, 2, P * S, H // 2, D)
+ kv = (
+ self.cpu_kv_cache_tensor[page_indexes]
+ .view(P, L, S, 2, H // 2, D)
+ .permute(1, 3, 0, 2, 4, 5)
+ .contiguous()
+ .view(L, 2, P * S, H // 2, D)
+ )
+ return kv
+ else:
+ kv = self.cpu_kv_cache_tensor[page_indexes] \
+ .view(P, L, S, 2, H // 2, D) \
+ .permute(1, 3, 4, 0, 2, 5).contiguous() \
+ .view(L, 2, H // 2, P * S, D)
+ return kv[:, :, :, :token_num, :].contiguous()
def _create_shm_cpu_kv_cache(self):
shm_ptr = create_shm_kv_cache_ptr(
From e958f11d6e95b83072441a44f2b3a3f0640247f8 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Wed, 8 Apr 2026 08:01:00 +0000
Subject: [PATCH 11/41] x2v support
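
Removes the leftover debug prints and changes get_query_for_t2i so the
unconditional query's thinking content is padded with IMG_TOKEN, one per image
token already present in the prompt. A standalone restatement of the resulting
pairing (build_t2i_query stands in for the tokenizer's _build_t2i_query helper;
the prompt already has the chat template applied):

    def build_t2i_queries(prompt, build_t2i_query, img_token, img_start_token):
        # conditional query: ensure the prompt ends with the image start tag
        image_len = prompt.count(img_token)
        query_condition = prompt if prompt.endswith(img_start_token) else prompt + img_start_token
        # unconditional query: empty text, thinking content padded with image tokens
        query_uncondition = build_t2i_query("", thinking_content=img_token * image_len)
        return query_condition, query_uncondition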
---
lightllm/models/neo_chat_moe/model.py | 11 +---
lightllm/server/api_cli.py | 2 +-
lightllm/server/api_openai.py | 23 ++++----
lightllm/server/core/objs/start_args_type.py | 6 ++-
lightllm/server/core/objs/x2i_params.py | 12 +++--
lightllm/server/httpserver/manager.py | 52 +++++++++++--------
lightllm/server/x2i_server/manager.py | 35 ++++---------
.../server/x2i_server/past_kv_cache_client.py | 4 +-
8 files changed, 69 insertions(+), 76 deletions(-)
diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py
index 900d2eb028..648d696877 100644
--- a/lightllm/models/neo_chat_moe/model.py
+++ b/lightllm/models/neo_chat_moe/model.py
@@ -152,20 +152,13 @@ def get_query_for_it2i(self, prompt: str):
query_condition = prompt + IMG_START_TOKEN if not prompt.endswith(IMG_START_TOKEN) else prompt
query_text_uncondition = self._build_t2i_query(IMG_TOKEN * image_len)
question_img_uncondition = self._build_t2i_query("")
- print(f"query_condition: {query_condition}")
- print(f"query_text_uncondition: {query_text_uncondition}")
- print(f"question_img_uncondition: {question_img_uncondition}")
return query_condition, query_text_uncondition, question_img_uncondition
def get_query_for_t2i(self, prompt: str):
# prompt is already applied
+ image_len = prompt.count(IMG_TOKEN)
query_condition = prompt + IMG_START_TOKEN if not prompt.endswith(IMG_START_TOKEN) else prompt
- # query_condition = self._build_t2i_query(
- # f"Please generate an image based on the following description: {prompt}",
- # thinking_content="\n\n\n\n")
- query_uncondition = self._build_t2i_query("")
- print(f"query_condition: {query_condition}", flush=True)
- print(f"query_uncondition: {query_uncondition}", flush=True)
+ query_uncondition = self._build_t2i_query("", thinking_content=IMG_TOKEN * image_len)
return query_condition, query_uncondition
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 96697fae22..78a3355872 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -637,7 +637,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--cpu_cache_storage_size",
type=float,
- default=2,
+ default=50,
help="""The capacity of cpu cache. GB used.""",
)
parser.add_argument(
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index 0f364cdc1e..94de694db9 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -538,17 +538,14 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
-async def _get_text_generator_input(request: ChatCompletionRequest, apply_chat_template: bool = True):
+async def _get_text_generator_input(request: ChatCompletionRequest):
from .api_http import g_objs
images, audios = _get_images_and_audios(request)
multimodal_params_dict = {"images": images, "audios": audios}
tools = _get_tools(request)
- if apply_chat_template:
- prompt = await build_prompt(request, tools)
- else:
- prompt = request.messages[-1].content
+ prompt = await build_prompt(request, tools)
sampling_params_dict = {
"do_sample": request.do_sample,
@@ -626,15 +623,13 @@ def _apply_image_generation_stop(chat_request: ChatCompletionRequest, image_star
async def _chat_completion_image_only(
request: ChatCompletionRequestV2,
raw_request: Request,
+ prompt: str,
+ multimodal_params: MultimodalParams,
+ x2i_params: X2IParams,
) -> ChatCompletionResponse:
from .api_http import g_objs
created_time = int(time.time())
- x2i_params = X2IParams()
- x2i_params.init_from_image_config(request.image_config)
-
- prompt, _, multimodal_params = await _get_text_generator_input(request, apply_chat_template=False)
-
images = await g_objs.httpserver_manager.generate_image(
prompt, x2i_params, multimodal_params.clone(), request=raw_request
)
@@ -680,9 +675,6 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
if "image" not in request.modalities:
return await chat_completions_impl(chat_request, raw_request)
- if request.modalities == ["image"]:
- return await _chat_completion_image_only(request, raw_request)
-
if request.n != 1:
return create_error_response(
HTTPStatus.BAD_REQUEST,
@@ -701,6 +693,9 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
x2i_params = X2IParams()
x2i_params.init_from_image_config(request.image_config)
+ if request.modalities == ["image"]:
+ return await _chat_completion_image_only(request, raw_request, prompt, multimodal_params, x2i_params)
+
if not request.stream:
from .req_id_generator import convert_sub_id_to_group_id
@@ -710,7 +705,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
completion_tokens = 0
finish_reason: Optional[str] = "stop"
group_request_id = None
- max_image_gen_num = 15 # TODO: make this configurable
+ max_image_gen_num = 2 # TODO: make this configurable
while max_image_gen_num > 0:
max_image_gen_num -= 1
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 5cf13ef19a..6a8d2d2eb3 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -125,7 +125,9 @@ class StartArgs:
vit_att_backend: List[str] = field(
default=("auto",), metadata={"choices": ["auto", "triton", "fa3", "sdpa", "xformers"]}
)
- llm_kv_type: str = field(default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"]})
+ llm_kv_type: str = field(
+ default="None", metadata={"choices": ["None", "int8kv", "int4kv", "fp8kv_sph", "fp8kv_spt"]}
+ )
llm_kv_quant_group_size: int = field(default=8)
sampling_backend: str = field(default="triton", metadata={"choices": ["triton", "sglang_kernel"]})
penalty_counter_mode: str = field(
@@ -144,7 +146,7 @@ class StartArgs:
nixl_pd_kv_page_size: int = field(default=1024)
pd_node_id: int = field(default=-1)
enable_cpu_cache: bool = field(default=False)
- cpu_cache_storage_size: float = field(default=2)
+ cpu_cache_storage_size: float = field(default=20)
cpu_cache_token_page_size: int = field(default=64)
enable_disk_cache: bool = field(default=False)
disk_cache_storage_size: float = field(default=10)
diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py
index 9a29ef949d..7704217a39 100644
--- a/lightllm/server/core/objs/x2i_params.py
+++ b/lightllm/server/core/objs/x2i_params.py
@@ -36,6 +36,8 @@ class X2IParams(ctypes.Structure):
("image_guidance_scale", ctypes.c_float),
("seed", ctypes.c_int),
("num_images", ctypes.c_int),
+ ("cfg_interval", ctypes.c_float * 2),
+ ("timestep_shift", ctypes.c_float),
("cfg_norm", ctypes.c_int),
("past_kvcache", PastKVCachePageList),
("past_kvcache_text", PastKVCachePageList),
@@ -47,11 +49,13 @@ class X2IParams(ctypes.Structure):
_width: int = 1024
_height: int = 1024
_steps: int = 50
- _guidance_scale: float = 7.0
- _image_guidance_scale: float = 7.0
+ _guidance_scale: float = 4.0
+ _image_guidance_scale: float = 1.0
_seed: int = 42
_num_images: int = 1
- _cfg_norm: CfgNormType = CfgNormType.NONE
+ _cfg_norm: CfgNormType = CfgNormType.GLOBAL
+ _cfg_interval: tuple = (-1.0, 2.0)
+ _timestep_shift: float = 3.0
def init(self, **kwargs):
def _get(key, default):
@@ -66,6 +70,8 @@ def _get(key, default):
self.seed = _get("seed", X2IParams._seed)
self.num_images = _get("num_images", X2IParams._num_images)
self.cfg_norm = _get("cfg_norm", X2IParams._cfg_norm)
+ self.cfg_interval = _get("cfg_interval", X2IParams._cfg_interval)
+ self.timestep_shift = _get("timestep_shift", X2IParams._timestep_shift)
self.past_kvcache = PastKVCachePageList()
self.past_kvcache_text = PastKVCachePageList()
self.past_kvcache_img = PastKVCachePageList()
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index aa590aacea..02ea465e32 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -99,6 +99,7 @@ def __init__(
if args.enable_multimodal_x2i:
from lightllm.server.x2i_server.past_kv_cache_client import PastKVCacheClient
+
self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=True, init_shm_data=False)
self.send_to_x2i = context.socket(zmq.PUSH)
self.send_to_x2i.connect(f"{args.zmq_mode}127.0.0.1:{args.x2i_port}")
@@ -298,12 +299,10 @@ async def generate(
start_time = time.time()
request_headers = request.headers if request is not None else {}
group_request_id = self.alloc_req_id(sampling_params, is_health_req)
-
try:
original_multimodal_params = None
if self.is_multinode_tp_master:
original_multimodal_params = copy.deepcopy(multimodal_params)
-
if self.pd_mode.is_P_or_NORMAL():
await multimodal_params.verify_and_preload(request)
@@ -342,7 +341,6 @@ async def generate(
# If the decode node's ready_kv_len equals the prefill side's len(prompt ids) - 1, no prefill is needed
# raise NixlPrefillNodeStopGenToken directly
raise NixlPrefillNodeStopGenToken(group_request_id=group_request_id)
-
# allocate resources and store them
alloced_req_indexes = []
while len(alloced_req_indexes) < sampling_params.n:
@@ -372,10 +370,10 @@ async def generate(
img_tokens = sum([img.token_num for img in multimodal_params.images])
img_len = len(multimodal_params.images)
kv_pages = self.past_kv_cache_client.allocate_pages(
- req_obj.request_id, req_obj.input_len, img_tokens, img_len)
+ req_obj.request_id, req_obj.input_len, img_tokens, img_len
+ )
req_obj.past_kv_cache_page_indexes.fill(kv_pages)
-
req_objs.append(req_obj)
logger.debug(
@@ -436,13 +434,13 @@ async def generate(
raise e
return
-
- async def generate_image(self, prompt: str, generation_params: X2IParams, multimodal_params: MultimodalParams, request: Request):
+ async def generate_image(
+ self, prompt: str, generation_params: X2IParams, multimodal_params: MultimodalParams, request: Request
+ ):
generate_req_ids = []
+
async def generation_wrapper(prompt, sample, multimodal, request):
- async for sub_req_id, _, metadata, finish_status in self.generate(
- prompt, sample, multimodal, request
- ):
+ async for sub_req_id, _, metadata, finish_status in self.generate(prompt, sample, multimodal, request):
kv_cache_item: PastKVCacheItem = self.past_kv_cache_client.get_pages_by_req_id(sub_req_id)
if kv_cache_item is None:
raise Exception(f"kv_cache_pages is None for sub_req_id {sub_req_id}")
@@ -457,31 +455,41 @@ async def generation_wrapper(prompt, sample, multimodal, request):
sample_params = SamplingParams()
sample_params.init(self.tokenizer, **{"img_gen_prefill": True})
img_len = len(multimodal_params.images)
+ image_guidance_scale = generation_params.image_guidance_scale
- if img_len > 0:
+ if img_len > 0 and image_guidance_scale != 1.0:
# call it2i
# fix prompt, add tag if img_len greater than s in prompt
+ # this branch is currently unused, because image_guidance_scale is recommended to stay at 1.0
prompt = self.tokenizer.fix_prompt(prompt, img_len)
- prompt_condition, prompt_text_uncondition, prompt_img_uncondition = self.tokenizer.get_query_for_it2i(prompt)
- (con_gen, text_uncon_gen, img_uncon_gen) = await asyncio.gather(*[
- generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
- generation_wrapper(prompt_text_uncondition, sample_params, multimodal_params.clone(), request),
- generation_wrapper(prompt_img_uncondition, sample_params, MultimodalParams(), request)])
+ prompt_condition, prompt_text_uncondition, prompt_img_uncondition = self.tokenizer.get_query_for_it2i(
+ prompt
+ )
+ (con_gen, text_uncon_gen, img_uncon_gen) = await asyncio.gather(
+ *[
+ generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
+ generation_wrapper(prompt_text_uncondition, sample_params, multimodal_params.clone(), request),
+ generation_wrapper(prompt_img_uncondition, sample_params, MultimodalParams(), request),
+ ]
+ )
generation_params.update_it2i(con_gen, text_uncon_gen, img_uncon_gen)
else:
# call t2i
prompt_condition, prompt_uncondition = self.tokenizer.get_query_for_t2i(prompt)
logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}")
- (con_gen, uncon_gen) = await asyncio.gather(*[
- generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
- generation_wrapper(prompt_uncondition, sample_params, multimodal_params, request)])
+ (con_gen, uncon_gen) = await asyncio.gather(
+ *[
+ generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
+ generation_wrapper(prompt_uncondition, sample_params, multimodal_params.clone(), request),
+ ]
+ )
generation_params.update_t2i(con_gen, uncon_gen)
# use the first request id as the gen image request id
x2i_req_id = generate_req_ids[0]
generation_params.request_id = x2i_req_id
- req_status = X2IReqStatus(generation_params, generate_req_ids)
+ req_status = X2IReqStatus(generation_params, generate_req_ids)
self.req_id_to_x2i_reqs[generation_params.request_id] = req_status
# send generation_params to generation server for image generation
@@ -859,7 +867,6 @@ async def loop_for_x2i(self):
except Exception as e:
logger.error(e, exc_info=e)
-
async def handle_loop(self):
self.recycle_event = asyncio.Event()
asyncio.create_task(self.recycle_resource_loop())
@@ -964,9 +971,10 @@ def can_release(self):
return False
return True
+
class X2IReqStatus:
def __init__(self, req_param: X2IParams, req_ids: List[int]):
self.req = req_param
self.req_ids = req_ids
self.event = asyncio.Event()
- self.response: X2IResponse = None
\ No newline at end of file
+ self.response: X2IResponse = None
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index e9e5e6de4b..a35ba51b09 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -15,7 +15,7 @@
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.process_check import start_parent_check_thread
from lightllm.utils.envs_utils import get_unique_server_name
-from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease
+from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease, CfgNormType
from lightllm.utils.dist_utils import set_current_device_id
from .past_kv_cache_client import PastKVCacheClient
@@ -61,22 +61,24 @@ async def wait_to_model_ready(self):
self.gen_pipe.create_generator(
config_json=self.args.x2v_gen_model_config,
)
- self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False})
+ self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": False})
+
# from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I
# self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device())
pass
async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams):
- past_kv_cache = self._truncate_kv_cache_to_compressed_len(
- past_kv_cache, param.past_kvcache.get_compressed_len()
- )
- past_kv_cache_text = self._truncate_kv_cache_to_compressed_len(
- past_kv_cache_text, param.past_kvcache_text.get_compressed_len()
+ self.gen_pipe.runner.set_inference_params(
+ index_offset_cond=param.past_kvcache.get_compressed_len(),
+ index_offset_uncond=param.past_kvcache_text.get_compressed_len(),
+ cfg_interval=param.cfg_interval,
+ cfg_scale=param.guidance_scale,
+ cfg_norm=CfgNormType(param.cfg_norm).as_str(),
+ timestep_shift=param.timestep_shift,
)
- self.gen_pipe.runner.set_kvcache_t2i(past_kv_cache, past_kv_cache_text)
+ self.gen_pipe.runner.set_kvcache(past_kv_cache, past_kv_cache_text)
image = self.gen_pipe.generate(
seed=param.seed,
- task="t2i",
save_result_path="", # return base64, no need to specify a path
target_shape=[param.height, param.width], # Height, Width
)
@@ -84,19 +86,9 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams
return [image]
async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams):
- past_kv_cache = self._truncate_kv_cache_to_compressed_len(
- past_kv_cache, param.past_kvcache.get_compressed_len()
- )
- past_kv_cache_text = self._truncate_kv_cache_to_compressed_len(
- past_kv_cache_text, param.past_kvcache_text.get_compressed_len()
- )
- past_kv_cache_img = self._truncate_kv_cache_to_compressed_len(
- past_kv_cache_img, param.past_kvcache_img.get_compressed_len()
- )
self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img)
image = self.gen_pipe.generate(
seed=param.seed,
- task="i2i",
save_result_path="", # return base64, no need to specify a path
target_shape=[param.height, param.width], # Height, Width
)
@@ -164,11 +156,6 @@ async def loop_for_netio_req(self):
def clean_up(self):
pass
- def _truncate_kv_cache_to_compressed_len(self, kv: torch.Tensor, compressed_len: int) -> torch.Tensor:
- seq = kv.shape[2]
- n = min(compressed_len, seq)
- return kv[:, :, :n:, :].contiguous()
-
def setup_devices(args: StartArgs):
devices = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
diff --git a/lightllm/server/x2i_server/past_kv_cache_client.py b/lightllm/server/x2i_server/past_kv_cache_client.py
index 9e9f0d99d4..6d08e34a2c 100644
--- a/lightllm/server/x2i_server/past_kv_cache_client.py
+++ b/lightllm/server/x2i_server/past_kv_cache_client.py
@@ -33,6 +33,7 @@ def __init__(self, only_create_meta_data: bool, init_shm_data: bool):
self.free_pages: List[int] = list(range(self.page_num))
self.lock = Lock()
self.cond = Condition(self.lock)
+ print("PastKVCacheClient init, page num: ", self.page_num, flush=True)
if not only_create_meta_data:
if init_shm_data:
@@ -86,12 +87,13 @@ def get_kv_cache_for_x2i(self, page_indexes: List[int], token_num: int) -> Optio
# (P, L, S, H, D) -> (P, L, S, 2, H // 2, D) -> (L, 2, P, S, H // 2, D) -> (L, 2, P * S, H // 2, D)
kv = (
self.cpu_kv_cache_tensor[page_indexes]
+ .contiguous()
.view(P, L, S, 2, H // 2, D)
.permute(1, 3, 0, 2, 4, 5)
.contiguous()
.view(L, 2, P * S, H // 2, D)
)
- return kv
+ return kv[:, :, :token_num, :, :].contiguous()
def _create_shm_cpu_kv_cache(self):
shm_ptr = create_shm_kv_cache_ptr(
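Among the changes above, X2IParams gains a cfg_interval field declared as ctypes.c_float * 2, i.e. a fixed-size pair embedded in the shared-memory structure. A minimal standalone sketch (not the real X2IParams) of declaring and filling such a field:

    import ctypes

    class ParamsSketch(ctypes.Structure):
        _fields_ = [
            ("guidance_scale", ctypes.c_float),
            ("cfg_interval", ctypes.c_float * 2),  # fixed-size float pair, like the new field
        ]

    p = ParamsSketch()
    p.guidance_scale = 4.0
    p.cfg_interval = (ctypes.c_float * 2)(-1.0, 2.0)  # construct the array type explicitly
    print(list(p.cfg_interval))  # [-1.0, 2.0]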
From 539576b82dba7f0e7cb8ab9e924fcfedf0b61bc9 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Wed, 8 Apr 2026 12:49:19 +0000
Subject: [PATCH 12/41] fix x2v acc, because of seed
---
lightllm/server/x2i_server/manager.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index 5865d47746..b4921b69ee 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -90,7 +90,7 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams
)
self.gen_pipe.runner.set_kvcache(past_kv_cache, past_kv_cache_text)
image = self.gen_pipe.generate(
- seed=param.seed,
+ seed=param.seed + param.past_kvcache.img_len,
save_result_path="", # return base64, no need to specify a path
target_shape=[param.height, param.width], # Height, Width
)
@@ -112,7 +112,7 @@ async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_i
)
self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img)
image = self.gen_pipe.generate(
- seed=param.seed,
+ seed=param.seed + param.past_kvcache_img.img_len,
save_result_path="", # return base64, no need to specify a path
target_shape=[param.height, param.width], # Height, Width
)
From 29dda22014c1ed48fd66b9dad92e4c1cc50c6c17 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Wed, 8 Apr 2026 13:36:20 +0000
Subject: [PATCH 13/41] fix chat template
---
lightllm/server/build_prompt.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py
index 7f16d519a2..a33e35e815 100644
--- a/lightllm/server/build_prompt.py
+++ b/lightllm/server/build_prompt.py
@@ -16,7 +16,10 @@ def init_tokenizer(args):
if chat_path is not None:
with open(chat_path, "r", encoding="utf-8") as f:
chat_template_str = f.read()
- tokenizer.chat_template = chat_template_str
+ if hasattr(tokenizer, "tokenizer"):
+ tokenizer.tokenizer.chat_template = chat_template_str
+ else:
+ tokenizer.chat_template = chat_template_str
return
# If chat_template.json exists in the tokenizer directory but chat_template.jinja does not,
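The fix above matters when the loaded tokenizer is a multimodal wrapper that holds the real Hugging Face tokenizer in a .tokenizer attribute: assigning chat_template on the wrapper would not reach the object that apply_chat_template actually uses. A minimal sketch of the same pattern, with a hypothetical wrapper class:

    class MultimodalTokenizerWrapper:
        # Hypothetical wrapper: the real HF tokenizer lives in .tokenizer.
        def __init__(self, hf_tokenizer):
            self.tokenizer = hf_tokenizer

    def set_chat_template(tok, chat_template_str: str):
        # Mirror the fix: reach through the wrapper when one is present.
        if hasattr(tok, "tokenizer"):
            tok.tokenizer.chat_template = chat_template_str
        else:
            tok.chat_template = chat_template_str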
From e4effb8d5818daee2ba5caf6fd80ab9b0dce3be6 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Wed, 8 Apr 2026 14:55:46 +0000
Subject: [PATCH 14/41] smart resize
---
lightllm/server/api_models.py | 6 +++++-
lightllm/server/api_openai.py | 2 +-
2 files changed, 6 insertions(+), 2 deletions(-)
diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py
index 643bba66e3..664a248e80 100644
--- a/lightllm/server/api_models.py
+++ b/lightllm/server/api_models.py
@@ -458,7 +458,11 @@ def get_resolution(self):
scale = self._size_multiplier[self.image_size]
w, h = base
- return int(w * scale), int(h * scale)
+ w, h = int(w * scale), int(h * scale)
+ from lightllm.models.neo_chat_moe.vision_process import smart_resize
+
+ h, w = smart_resize(h, w, factor=32, min_pixels=512 * 512, max_pixels=2048 * 2048)
+ return w, h
class ChatCompletionRequestV2(ChatCompletionRequest):
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index f446f3e536..e9603f008d 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -636,7 +636,7 @@ async def _chat_completion_image_only(
response_images = _message_contents_from_raw_images(images, request.image_config.image_type)
chat_message = ChatMessage(
role="assistant",
- content=prompt,
+ content="",
images=response_images if response_images else None,
)
choice = ChatCompletionResponseChoice(
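get_resolution now passes the preset resolution through smart_resize from neo_chat_moe.vision_process. That helper is not shown in this series; the sketch below is only an assumed, common formulation (round both sides to a multiple of factor, then rescale so the pixel count stays within [min_pixels, max_pixels]):

    import math

    def smart_resize_sketch(height: int, width: int, factor: int = 32,
                            min_pixels: int = 512 * 512, max_pixels: int = 2048 * 2048):
        # Round each side to the nearest multiple of `factor`.
        h = max(factor, round(height / factor) * factor)
        w = max(factor, round(width / factor) * factor)
        if h * w > max_pixels:
            # Shrink while keeping the aspect ratio and the factor alignment.
            beta = math.sqrt((height * width) / max_pixels)
            h = math.floor(height / beta / factor) * factor
            w = math.floor(width / beta / factor) * factor
        elif h * w < min_pixels:
            # Grow symmetrically for the same reason.
            beta = math.sqrt(min_pixels / (height * width))
            h = math.ceil(height * beta / factor) * factor
            w = math.ceil(width * beta / factor) * factor
        return h, w

    print(smart_resize_sketch(768, 1344))  # already aligned and in range -> (768, 1344)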
From a71817ae83e15af1f137704f452158e60bf4747a Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Thu, 9 Apr 2026 05:50:43 +0000
Subject: [PATCH 15/41] nit.
---
lightllm/server/api_openai.py | 2 +-
.../router/model_infer/mode_backend/past_kv_cache.py | 3 ---
lightllm/server/x2i_server/manager.py | 7 ++++---
3 files changed, 5 insertions(+), 7 deletions(-)
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index e9603f008d..52fe59aa39 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -863,7 +863,7 @@ async def stream_result() -> AsyncGenerator[bytes, None]:
prompt, x2i_params, multimodal_params.clone(), request=raw_request
)
- if len(images) == 0:
+ if images is None or len(images) == 0:
logger.warning(f"No image generated by x2i: {prompt[-100:]}, exit...")
break
diff --git a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
index de80da5d01..9035de219b 100644
--- a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
+++ b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
@@ -140,9 +140,6 @@ def update_past_kv_cache_task_states(self):
self.past_kv_cache_task.appendleft(task)
break
- if len(trans_ok_tasks) == 0:
- return
-
ok_tasks_num = torch.tensor(len(trans_ok_tasks))
dist.all_reduce(ok_tasks_num, op=dist.ReduceOp.MIN, group=self.sync_task_status_group)
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index b4921b69ee..d8ed2a4c7d 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -1,11 +1,11 @@
import zmq
-import zmq.asyncio
import asyncio
import uvloop
import inspect
import setproctitle
import pickle
import torch
+import time
import os
from typing import List
from lightllm.server.core.objs import StartArgs
@@ -149,10 +149,12 @@ async def loop_for_fwd(self):
images = []
logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with: {x2i_param}")
+ start_t = time.time()
if is_t2i:
images = await self.t2i_generate(past_kv_cache, past_kv_cache_text, x2i_param)
else:
images = await self.it2i_generate(past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param)
+ logger.info(f"generate {len(images)} images done, cost {time.time() - start_t:.2f}s")
self.send_to_httpserver.send_pyobj(
X2IResponse(request_id=x2i_param.request_id, images=images), protocol=pickle.HIGHEST_PROTOCOL
@@ -162,8 +164,7 @@ async def loop_for_fwd(self):
self.send_to_httpserver.send_pyobj(
X2IResponse(request_id=x2i_param.request_id, images=None), protocol=pickle.HIGHEST_PROTOCOL
)
-
- logger.error(e)
+ logger.error(e, exc_info=e)
async def loop_for_netio_req(self):
while True:
From 044158c89008ddbcb46b8831c028d821f717923c Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Thu, 9 Apr 2026 10:05:08 +0000
Subject: [PATCH 16/41] fix device.
---
lightllm/server/api_start.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 5ce80bf419..e034d9f437 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -366,6 +366,8 @@ def normal_or_p_d_start(args):
)
if origin_devices:
os.environ["CUDA_VISIBLE_DEVICES"] = origin_devices
+ else:
+ os.environ.pop("CUDA_VISIBLE_DEVICES", None)
if args.enable_cpu_cache:
from .multi_level_kv_cache.manager import start_multi_level_kv_cache_manager
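The fix above covers the case where CUDA_VISIBLE_DEVICES was not set before the server temporarily overrode it: restoring nothing would leave the temporary value in place, so the variable is popped instead. A minimal sketch of the save/override/restore pattern as a hypothetical context manager (not code from the series):

    import os
    from contextlib import contextmanager

    @contextmanager
    def scoped_cuda_visible_devices(devices: str):
        origin = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        os.environ["CUDA_VISIBLE_DEVICES"] = devices
        try:
            yield
        finally:
            if origin:
                os.environ["CUDA_VISIBLE_DEVICES"] = origin  # restore the original value
            else:
                os.environ.pop("CUDA_VISIBLE_DEVICES", None)  # it was unset before, so unset it again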
From 067a6795302ef15bad1162c7146e78f0e26ea255 Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Fri, 10 Apr 2026 09:15:08 +0000
Subject: [PATCH 17/41] support distributed lightx2v.
---
lightllm/server/api_start.py | 10 +-
lightllm/server/core/objs/start_args_type.py | 2 +
lightllm/server/httpserver/manager.py | 4 +-
.../server/x2i_server/lightx2v/adapter.py | 145 ++++++++++++++++++
lightllm/server/x2i_server/manager.py | 133 ++++++++++------
5 files changed, 239 insertions(+), 55 deletions(-)
create mode 100644 lightllm/server/x2i_server/lightx2v/adapter.py
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index e034d9f437..6413ec22b0 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -250,7 +250,7 @@ def normal_or_p_d_start(args):
node_world_size = args.tp // args.nnodes
can_use_ports = alloc_can_use_network_port(
- num=12 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports
+ num=14 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports
)
logger.info(f"alloced ports: {can_use_ports}")
(
@@ -266,8 +266,10 @@ def normal_or_p_d_start(args):
pd_decode_rpyc_port,
x2i_port,
http_server_port_for_x2i,
- ) = can_use_ports[0:12]
- can_use_ports = can_use_ports[12:]
+ x2i_worker_nccl_port,
+ x2i_worker_task_port,
+ ) = can_use_ports[0:14]
+ can_use_ports = can_use_ports[14:]
visual_model_tp_ports = []
visual_nccl_ports = []
@@ -294,6 +296,8 @@ def normal_or_p_d_start(args):
args.visual_nccl_ports = visual_nccl_ports
args.x2i_port = x2i_port
args.http_server_port_for_x2i = http_server_port_for_x2i
+ args.x2i_worker_task_port = x2i_worker_task_port
+ args.x2i_worker_nccl_port = x2i_worker_nccl_port
# allocate the ports that will be used in P/D-separated mode
args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size]
# the id used to identify a node in P/D-separated mode
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 17cdeed06c..ece47a2bf1 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -166,6 +166,8 @@ class StartArgs:
http_server_port_for_x2i: int = field(default=None)
x2i_server_used_gpus: int = field(default=1)
x2i_use_naive_impl: bool = field(default=False)
+ x2i_worker_task_port: int = field(default=None)
+ x2i_worker_nccl_port: int = field(default=None)
# multi_modal
enable_multimodal_x2i: bool = field(default=False)
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 02ea465e32..6ca8e86bc6 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -104,7 +104,7 @@ def __init__(
self.send_to_x2i = context.socket(zmq.PUSH)
self.send_to_x2i.connect(f"{args.zmq_mode}127.0.0.1:{args.x2i_port}")
self.recv_from_x2i = context.socket(zmq.PULL)
- self.recv_from_x2i.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port_for_x2i}")
+ self.recv_from_x2i.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port_for_x2i}")
self.req_id_to_x2i_reqs: Dict[int, X2IReqStatus] = {}
self.shm_req_manager = ShmReqManager()
@@ -477,7 +477,7 @@ async def generation_wrapper(prompt, sample, multimodal, request):
else:
# call t2i
prompt_condition, prompt_uncondition = self.tokenizer.get_query_for_t2i(prompt)
- logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}")
+ # logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}")
(con_gen, uncon_gen) = await asyncio.gather(
*[
generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py
new file mode 100644
index 0000000000..cd0de51d88
--- /dev/null
+++ b/lightllm/server/x2i_server/lightx2v/adapter.py
@@ -0,0 +1,145 @@
+import inspect
+import torch
+import torch.distributed as dist
+import zmq
+import setproctitle
+import asyncio
+import os
+
+from lightllm.server.core.objs import StartArgs
+from lightllm.utils.log_utils import init_logger
+from lightllm.utils.graceful_utils import graceful_registry
+from lightllm.utils.process_check import start_parent_check_thread
+from lightllm.utils.envs_utils import get_unique_server_name
+from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease, CfgNormType
+from ..past_kv_cache_client import PastKVCacheClient
+
+logger = init_logger(__name__)
+
+
+class LightX2VServer:
+ def __init__(self, args: StartArgs, rank: int, world_size: int):
+ self.args = args
+ self.rank = rank
+ self.world_size = world_size
+
+ context = zmq.Context(2)
+
+ # receive task from manager
+ self.task_socket = context.socket(zmq.SUB)
+ self.task_socket.connect(f"{args.zmq_mode}127.0.0.1:{self.args.x2i_worker_task_port}")
+ self.task_socket.setsockopt(zmq.SUBSCRIBE, b"")
+
+ if self.rank == 0:
+ # send result back
+ self.result_socket = context.socket(zmq.PUSH)
+ self.result_socket.connect(f"{args.zmq_mode}127.0.0.1:{self.args.http_server_port_for_x2i}")
+
+ self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=False)
+ torch.cuda.set_device(rank)
+ self._init_pipeline()
+
+ def _init_pipeline(self):
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
+ os.environ["MASTER_PORT"] = str(self.args.x2i_worker_nccl_port)
+ os.environ["RANK"] = str(self.rank)
+ os.environ["WORLD_SIZE"] = str(self.world_size)
+
+ from lightx2v import LightX2VPipeline
+
+ self.pipe = LightX2VPipeline(
+ model_path=self.args.model_dir,
+ model_cls="neopp",
+ support_tasks=["t2i", "i2i"],
+ )
+ self.pipe.create_generator(config_json=self.args.x2v_gen_model_config)
+ self.pipe.modify_config({
+ "load_kv_cache_in_pipeline_for_debug": False,
+ "save_result_for_debug": False})
+
+ async def run(self):
+ while True:
+ param: X2IParams = self.task_socket.recv_pyobj()
+
+ try:
+ images = self._process(param)
+
+ if self.rank == 0:
+ self.result_socket.send_pyobj(X2IResponse(request_id=param.request_id, images=images))
+
+ except Exception as e:
+ logger.error(f"Error processing request {param.request_id}: {str(e)}", exc_info=e)
+ if self.rank == 0:
+ self.result_socket.send_pyobj(X2IResponse(request_id=param.request_id, images=None))
+
+ def _process(self, param: X2IParams):
+ is_t2i = param.past_kvcache_img.is_empty()
+
+ self.pipe.runner.set_inference_params(
+ index_offset_cond=param.past_kvcache.get_compressed_len(),
+ index_offset_uncond=param.past_kvcache_text.get_compressed_len(),
+ cfg_interval=param.cfg_interval,
+ cfg_scale=param.guidance_scale,
+ cfg_norm=CfgNormType(param.cfg_norm).as_str(),
+ timestep_shift=param.timestep_shift,
+ )
+ past_kv_cache = self.past_kv_cache_client.get_kv_cache_for_x2i(
+ param.past_kvcache.get_all(), param.past_kvcache.token_len)
+ past_kv_cache_text = self.past_kv_cache_client.get_kv_cache_for_x2i(
+ param.past_kvcache_text.get_all(), param.past_kvcache_text.token_len)
+ past_kv_cache_img = None
+ if not is_t2i:
+ past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i(
+ param.past_kvcache_img.get_all(), param.past_kvcache_img.token_len)
+
+ dist.barrier() # ensure all workers have got the kv cache before generation starts
+
+ if self.rank == 0:
+ # release
+ self.result_socket.send_pyobj(X2ICacheRelease(request_id=param.request_id))
+
+ if is_t2i:
+ self.pipe.runner.set_kvcache(
+ past_kv_cache,
+ past_kv_cache_text,
+ )
+ else:
+ self.pipe.runner.set_kvcache_i2i(
+ past_kv_cache,
+ past_kv_cache_text,
+ past_kv_cache_img,
+ )
+ image = self.pipe.generate(
+ seed=param.seed + param.past_kvcache.img_len,
+ save_result_path="",
+ target_shape=[param.height, param.width],
+ )
+
+ return [image]
+
+
+def start_x2v_process(args: StartArgs, rank: int, world_size: int, pipe_writer):
+
+ # register graceful shutdown handling
+ graceful_registry(inspect.currentframe().f_code.co_name)
+ setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::x2v_server_{rank}")
+ start_parent_check_thread()
+
+ try:
+ x2v_server = LightX2VServer(args=args, rank=rank, world_size=world_size)
+ except Exception as e:
+ logger.exception(str(e), exc_info=e)
+ raise e
+
+ pipe_writer.send("init ok")
+
+ def handle_exception(loop, context):
+ logger.exception(f"X2VServer Caught exception: {str(context)}")
+
+ loop = asyncio.new_event_loop()
+ loop.set_exception_handler(handle_exception)
+ asyncio.set_event_loop(loop)
+
+ loop.run_until_complete(x2v_server.run())
+
+ return
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index d8ed2a4c7d..8b86e12030 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -6,6 +6,7 @@
import pickle
import torch
import time
+import multiprocessing as mp
import os
from typing import List
from lightllm.server.core.objs import StartArgs
@@ -17,6 +18,7 @@
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease, CfgNormType
from lightllm.utils.dist_utils import set_current_device_id
+from lightllm.utils.start_utils import start_submodule_processes
from .past_kv_cache_client import PastKVCacheClient
logger = init_logger(__name__)
@@ -29,8 +31,23 @@
3. call llm gen to obtain past key values
4. call x2v to generate images and pass the key values to it
5. return the generated images.
-"""
+ +-------------------+
+ | X2IManager |
+ +---------+---------+
+ |
+ (broadcast)
+ |
+ +------+------+------+
+ | | |
+ Worker0 Worker1 ...
+ (rank0) (rank1)
+ | |
+ +------ allreduce / sync ----+
+ |
+ only rank0
+ returns result
+"""
class X2IManager:
def __init__(
@@ -40,40 +57,51 @@ def __init__(
context = zmq.Context(2)
self.args = args
+ # from http server
self.zmq_recv_socket = context.socket(zmq.PULL)
self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.x2i_port}")
+ # to http server
self.send_to_httpserver = context.socket(zmq.PUSH)
- self.send_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port_for_x2i}")
+ self.send_to_httpserver.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port_for_x2i}")
+
+ self.use_naive_x2i = args.x2i_use_naive_impl
+ self.world_size = args.x2i_server_used_gpus
+
+ if not self.use_naive_x2i and self.world_size > 1:
+ # send to workers
+ self.worker_pub = context.socket(zmq.PUB)
+ self.worker_pub.bind(f"{args.zmq_mode}127.0.0.1:{args.x2i_worker_task_port}")
self.waiting_reqs: List[X2IParams] = []
self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True)
- self.use_naive_x2i = args.x2i_use_naive_impl
async def wait_to_model_ready(self):
- if self.use_naive_x2i:
- from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I
- self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device())
- return
-
- from lightx2v import LightX2VPipeline
-
- self.gen_pipe = LightX2VPipeline(
- model_path=self.args.model_dir,
- model_cls="neopp",
- support_tasks=["t2i", "i2i"],
- )
- self.gen_pipe.create_generator(
- config_json=self.args.x2v_gen_model_config,
- )
- self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": False})
-
- # from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I
- # self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device())
- pass
+ if self.world_size <= 1:
+ if self.use_naive_x2i:
+ from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I
+ self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device())
+ else:
+ from lightx2v import LightX2VPipeline
+
+ self.gen_pipe = LightX2VPipeline(
+ model_path=self.args.model_dir,
+ model_cls="neopp",
+ support_tasks=["t2i", "i2i"],
+ )
+ self.gen_pipe.create_generator(
+ config_json=self.args.x2v_gen_model_config,
+ )
+ self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": False})
+ else:
+ # distributed x2v
+ from lightllm.server.x2i_server.lightx2v.adapter import start_x2v_process
+ funcs = [start_x2v_process] * self.world_size
+ args = [(self.args, rank, self.world_size) for rank in range(self.world_size)]
+ start_submodule_processes(funcs, args)
async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams):
if self.use_naive_x2i:
@@ -127,38 +155,42 @@ async def loop_for_fwd(self):
x2i_param = self.waiting_reqs.pop(0)
- past_kv_cache = self.past_kv_cache_client.get_kv_cache_for_x2i(
- x2i_param.past_kvcache.get_all(), x2i_param.past_kvcache.token_len, self.use_naive_x2i
- )
-
- past_kv_cache_text = self.past_kv_cache_client.get_kv_cache_for_x2i(
- x2i_param.past_kvcache_text.get_all(), x2i_param.past_kvcache_text.token_len, self.use_naive_x2i
- )
- is_t2i = x2i_param.past_kvcache_img.is_empty()
+ if not self.use_naive_x2i and self.world_size > 1:
+ # broadcast to workers
+ self.worker_pub.send_pyobj(x2i_param, protocol=pickle.HIGHEST_PROTOCOL)
+ else:
+ past_kv_cache = self.past_kv_cache_client.get_kv_cache_for_x2i(
+ x2i_param.past_kvcache.get_all(), x2i_param.past_kvcache.token_len, self.use_naive_x2i
+ )
- past_kv_cache_img = None
- if not is_t2i: # t2i
- past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i(
- x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len, self.use_naive_x2i
+ past_kv_cache_text = self.past_kv_cache_client.get_kv_cache_for_x2i(
+ x2i_param.past_kvcache_text.get_all(), x2i_param.past_kvcache_text.token_len, self.use_naive_x2i
)
+ is_t2i = x2i_param.past_kvcache_img.is_empty()
- # release
- self.send_to_httpserver.send_pyobj(
- X2ICacheRelease(request_id=x2i_param.request_id), protocol=pickle.HIGHEST_PROTOCOL
- )
+ past_kv_cache_img = None
+ if not is_t2i: # t2i
+ past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i(
+ x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len, self.use_naive_x2i
+ )
- images = []
- logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with: {x2i_param}")
- start_t = time.time()
- if is_t2i:
- images = await self.t2i_generate(past_kv_cache, past_kv_cache_text, x2i_param)
- else:
- images = await self.it2i_generate(past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param)
- logger.info(f"generate {len(images)} images done, cost {time.time() - start_t:.2f}s")
+ # release
+ self.send_to_httpserver.send_pyobj(
+ X2ICacheRelease(request_id=x2i_param.request_id), protocol=pickle.HIGHEST_PROTOCOL
+ )
- self.send_to_httpserver.send_pyobj(
- X2IResponse(request_id=x2i_param.request_id, images=images), protocol=pickle.HIGHEST_PROTOCOL
- )
+ images = []
+ logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with: {x2i_param}")
+ start_t = time.time()
+ if is_t2i:
+ images = await self.t2i_generate(past_kv_cache, past_kv_cache_text, x2i_param)
+ else:
+ images = await self.it2i_generate(past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param)
+ logger.info(f"generate {len(images)} images done, cost {time.time() - start_t:.2f}s")
+
+ self.send_to_httpserver.send_pyobj(
+ X2IResponse(request_id=x2i_param.request_id, images=images), protocol=pickle.HIGHEST_PROTOCOL
+ )
except Exception as e:
self.send_to_httpserver.send_pyobj(
@@ -190,6 +222,7 @@ def setup_devices(args: StartArgs):
devices = [int(x.strip()) for x in devices.split(",") if x.strip()]
llm_need_gpus = args.tp * args.dp
+ # llm_need_gpus = 0
x2i_need_gpus = args.x2i_server_used_gpus
if len(devices) < llm_need_gpus + x2i_need_gpus:
raise ValueError(f"devices {devices} not enough, need {llm_need_gpus} and {x2i_need_gpus}")
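Each LightX2V worker started above exports the standard torch.distributed rendezvous variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) before building its pipeline; the pipeline itself and the later dist.new_group call are what actually join the group. A minimal sketch of that bootstrap; the explicit init_process_group call is only there to make the sketch self-contained and is not part of the adapter:

    import os
    import torch
    import torch.distributed as dist

    def init_worker(rank: int, world_size: int, nccl_port: int):
        # Env-var rendezvous, mirroring LightX2VServer._init_pipeline above.
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = str(nccl_port)
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        torch.cuda.set_device(rank)
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)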
From 6cf6a101138410d5d9d565e08d505e4574a45151 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 13 Apr 2026 03:22:44 +0000
Subject: [PATCH 18/41] change task distribution from pub/sub to push/pull
---
.../server/x2i_server/lightx2v/adapter.py | 35 ++++++++++++-------
lightllm/server/x2i_server/manager.py | 18 +---------
2 files changed, 24 insertions(+), 29 deletions(-)
diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py
index cd0de51d88..47472a7c53 100644
--- a/lightllm/server/x2i_server/lightx2v/adapter.py
+++ b/lightllm/server/x2i_server/lightx2v/adapter.py
@@ -2,6 +2,7 @@
import torch
import torch.distributed as dist
import zmq
+import zmq.asyncio
import setproctitle
import asyncio
import os
@@ -23,14 +24,13 @@ def __init__(self, args: StartArgs, rank: int, world_size: int):
self.rank = rank
self.world_size = world_size
- context = zmq.Context(2)
-
# receive task from manager
- self.task_socket = context.socket(zmq.SUB)
- self.task_socket.connect(f"{args.zmq_mode}127.0.0.1:{self.args.x2i_worker_task_port}")
- self.task_socket.setsockopt(zmq.SUBSCRIBE, b"")
-
if self.rank == 0:
+ context = zmq.asyncio.Context(2)
+ self.task_socket = context.socket(zmq.PULL)
+ self.task_socket.connect(f"{args.zmq_mode}127.0.0.1:{self.args.x2i_worker_task_port}")
+ # self.task_socket.setsockopt(zmq.SUBSCRIBE, b"")
+
# send result back
self.result_socket = context.socket(zmq.PUSH)
self.result_socket.connect(f"{args.zmq_mode}127.0.0.1:{self.args.http_server_port_for_x2i}")
@@ -39,6 +39,8 @@ def __init__(self, args: StartArgs, rank: int, world_size: int):
torch.cuda.set_device(rank)
self._init_pipeline()
+ self.task_dist_group = dist.new_group(backend="gloo")
+
def _init_pipeline(self):
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(self.args.x2i_worker_nccl_port)
@@ -59,20 +61,29 @@ def _init_pipeline(self):
async def run(self):
while True:
- param: X2IParams = self.task_socket.recv_pyobj()
try:
- images = self._process(param)
+ if self.rank == 0:
+ param: X2IParams = await self.task_socket.recv_pyobj()
+ dist.broadcast_object_list([param], src=0, group=self.task_dist_group)
+ else:
+ params = [None]
+ dist.broadcast_object_list(params, src=0, group=self.task_dist_group)
+ param: X2IParams = params[0]
+
+ assert param is not None, "Received None param in x2v worker, this should not happen."
+
+ images = await self._process(param)
if self.rank == 0:
- self.result_socket.send_pyobj(X2IResponse(request_id=param.request_id, images=images))
+ await self.result_socket.send_pyobj(X2IResponse(request_id=param.request_id, images=images))
except Exception as e:
logger.error(f"Error processing request {param.request_id}: {str(e)}", exc_info=e)
if self.rank == 0:
- self.result_socket.send_pyobj(X2IResponse(request_id=param.request_id, images=None))
+ await self.result_socket.send_pyobj(X2IResponse(request_id=param.request_id, images=None))
- def _process(self, param: X2IParams):
+ async def _process(self, param: X2IParams):
is_t2i = param.past_kvcache_img.is_empty()
self.pipe.runner.set_inference_params(
@@ -96,7 +107,7 @@ def _process(self, param: X2IParams):
if self.rank == 0:
# release
- self.result_socket.send_pyobj(X2ICacheRelease(request_id=param.request_id))
+ await self.result_socket.send_pyobj(X2ICacheRelease(request_id=param.request_id))
if is_t2i:
self.pipe.runner.set_kvcache(
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index 8b86e12030..f715df74b5 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -31,22 +31,6 @@
3. call llm gen to obtain past key values
4. call x2v to generate images and pass the key values to it
5. return the generated images.
-
- +-------------------+
- | X2IManager |
- +---------+---------+
- |
- (broadcast)
- |
- +------+------+------+
- | | |
- Worker0 Worker1 ...
- (rank0) (rank1)
- | |
- +------ allreduce / sync ----+
- |
- only rank0
- returns result
"""
class X2IManager:
@@ -70,7 +54,7 @@ def __init__(
if not self.use_naive_x2i and self.world_size > 1:
# send to workers
- self.worker_pub = context.socket(zmq.PUB)
+ self.worker_pub = context.socket(zmq.PUSH)
self.worker_pub.bind(f"{args.zmq_mode}127.0.0.1:{args.x2i_worker_task_port}")
self.waiting_reqs: List[X2IParams] = []
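With the change above, only rank 0 holds a PULL socket to the manager; every other worker receives the task through a gloo-backed dist.broadcast_object_list. A minimal sketch of that receive-then-broadcast pattern:

    import torch.distributed as dist

    def recv_task_for_all_ranks(task_socket, rank: int, group):
        # Rank 0 pulls the next task from ZMQ and shares it with every other rank,
        # mirroring LightX2VServer.run above.
        if rank == 0:
            task = task_socket.recv_pyobj()
            dist.broadcast_object_list([task], src=0, group=group)
            return task
        holder = [None]
        dist.broadcast_object_list(holder, src=0, group=group)
        return holder[0]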
From 29850a61a9e4132fc4e226b437c07decad7adda3 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 13 Apr 2026 05:43:11 +0000
Subject: [PATCH 19/41] use poller.
---
lightllm/server/httpserver/manager.py | 9 ++++++++-
lightllm/server/x2i_server/lightx2v/adapter.py | 1 -
2 files changed, 8 insertions(+), 2 deletions(-)
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 6ca8e86bc6..4a4bf937a3 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -842,10 +842,17 @@ async def recycle_resource_loop(self):
return
async def loop_for_x2i(self):
+ poller = zmq.asyncio.Poller()
+ poller.register(self.recv_from_x2i, zmq.POLLIN)
while True:
try:
- recv_obj = await asyncio.wait_for(self.recv_from_x2i.recv_pyobj(), timeout=0.05)
+ events = dict(await poller.poll(timeout=50))
+
+ if self.recv_from_x2i in events:
+ recv_obj = await self.recv_from_x2i.recv_pyobj()
+ else:
+ continue
if isinstance(recv_obj, X2ICacheRelease):
status = self.req_id_to_x2i_reqs[recv_obj.request_id]
diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py
index 47472a7c53..76900a92f2 100644
--- a/lightllm/server/x2i_server/lightx2v/adapter.py
+++ b/lightllm/server/x2i_server/lightx2v/adapter.py
@@ -29,7 +29,6 @@ def __init__(self, args: StartArgs, rank: int, world_size: int):
context = zmq.asyncio.Context(2)
self.task_socket = context.socket(zmq.PULL)
self.task_socket.connect(f"{args.zmq_mode}127.0.0.1:{self.args.x2i_worker_task_port}")
- # self.task_socket.setsockopt(zmq.SUBSCRIBE, b"")
# send result back
self.result_socket = context.socket(zmq.PUSH)
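The loop_for_x2i change above waits on the socket with a zmq.asyncio.Poller and a 50 ms timeout instead of wrapping recv_pyobj() in asyncio.wait_for. A minimal standalone sketch of that polling loop:

    import zmq
    import zmq.asyncio

    async def poll_for_objects(sock):
        poller = zmq.asyncio.Poller()
        poller.register(sock, zmq.POLLIN)
        while True:
            events = dict(await poller.poll(timeout=50))  # milliseconds
            if sock not in events:
                continue  # nothing arrived in this window
            obj = await sock.recv_pyobj()
            print("received", type(obj).__name__)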
From 40a16b0e12cb7b1f9ede824dda1dc82912bbd1f2 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 13 Apr 2026 06:38:59 +0000
Subject: [PATCH 20/41] add num_images for x2v
---
lightllm/server/x2i_server/manager.py | 51 +++++++++++++++++----------
1 file changed, 32 insertions(+), 19 deletions(-)
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index f715df74b5..28462da8f1 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -33,6 +33,7 @@
5. return the generated images.
"""
+
class X2IManager:
def __init__(
self,
@@ -61,12 +62,12 @@ def __init__(
self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True)
-
async def wait_to_model_ready(self):
if self.world_size <= 1:
if self.use_naive_x2i:
from lightllm.server.x2i_server.naive.modeling_neo_chat import NEOX2I
+
self.naive_x2i = NEOX2I(self.args.model_dir, torch.cuda.current_device())
else:
from lightx2v import LightX2VPipeline
@@ -79,10 +80,13 @@ async def wait_to_model_ready(self):
self.gen_pipe.create_generator(
config_json=self.args.x2v_gen_model_config,
)
- self.gen_pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": False})
+ self.gen_pipe.modify_config(
+ {"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": False}
+ )
else:
# distributed x2v
from lightllm.server.x2i_server.lightx2v.adapter import start_x2v_process
+
funcs = [start_x2v_process] * self.world_size
args = [(self.args, rank, self.world_size) for rank in range(self.world_size)]
start_submodule_processes(funcs, args)
@@ -100,14 +104,16 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams
cfg_norm=CfgNormType(param.cfg_norm).as_str(),
timestep_shift=param.timestep_shift,
)
- self.gen_pipe.runner.set_kvcache(past_kv_cache, past_kv_cache_text)
- image = self.gen_pipe.generate(
- seed=param.seed + param.past_kvcache.img_len,
- save_result_path="", # return base64, no need to specify a path
- target_shape=[param.height, param.width], # Height, Width
- )
- # images = self.naive_x2i.t2i(past_kv_cache, past_kv_cache_text, param)
- return [image]
+ images = []
+ for i in range(param.num_images):
+ self.gen_pipe.runner.set_kvcache(past_kv_cache, past_kv_cache_text)
+ image = self.gen_pipe.generate(
+ seed=param.seed + param.past_kvcache.img_len + i,
+ save_result_path="", # return base64, no need to specify a path
+ target_shape=[param.height, param.width], # Height, Width
+ )
+ images.append(image)
+ return images
async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams):
if self.use_naive_x2i:
@@ -122,13 +128,16 @@ async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_i
cfg_norm=CfgNormType(param.cfg_norm).as_str(),
timestep_shift=param.timestep_shift,
)
- self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img)
- image = self.gen_pipe.generate(
- seed=param.seed + param.past_kvcache_img.img_len,
- save_result_path="", # return base64, no need to specify a path
- target_shape=[param.height, param.width], # Height, Width
- )
- return [image]
+ images = []
+ for i in range(param.num_images):
+ self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img)
+ image = self.gen_pipe.generate(
+ seed=param.seed + param.past_kvcache_img.img_len + i,
+ save_result_path="", # return base64, no need to specify a path
+ target_shape=[param.height, param.width], # Height, Width
+ )
+ images.append(image)
+ return images
async def loop_for_fwd(self):
while True:
@@ -155,7 +164,9 @@ async def loop_for_fwd(self):
past_kv_cache_img = None
if not is_t2i: # t2i
past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i(
- x2i_param.past_kvcache_img.get_all(), x2i_param.past_kvcache_img.token_len, self.use_naive_x2i
+ x2i_param.past_kvcache_img.get_all(),
+ x2i_param.past_kvcache_img.token_len,
+ self.use_naive_x2i,
)
# release
@@ -169,7 +180,9 @@ async def loop_for_fwd(self):
if is_t2i:
images = await self.t2i_generate(past_kv_cache, past_kv_cache_text, x2i_param)
else:
- images = await self.it2i_generate(past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param)
+ images = await self.it2i_generate(
+ past_kv_cache, past_kv_cache_text, past_kv_cache_img, x2i_param
+ )
logger.info(f"generate {len(images)} images done, cost {time.time() - start_t:.2f}s")
self.send_to_httpserver.send_pyobj(
From af5c710a4aad463a6c8e1f025b33ba6c36925803 Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Mon, 13 Apr 2026 11:18:39 +0000
Subject: [PATCH 21/41] fixup.
---
lightllm/server/api_cli.py | 7 +++++++
lightllm/server/core/objs/start_args_type.py | 1 +
lightllm/server/httpserver/manager.py | 11 ++++++++---
lightllm/server/x2i_server/lightx2v/adapter.py | 3 ++-
lightllm/server/x2i_server/manager.py | 3 +--
5 files changed, 19 insertions(+), 6 deletions(-)
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 34b7226470..6efb2a8843 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -313,6 +313,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=1,
help="Number of GPUs to use for x2i server (requires --enable_multimodal_x2i).",
)
+ parser.add_argument(
+ "--x2i_server_deploy_mode",
+ type=str,
+ choices=["colocate", "separate"],
+ default="colocate",
+ help="Deployment mode for the x2i server. 'colocate' means the x2i server will run on the same gpus as the llm server; 'separate' gives it dedicated gpus.",
+ )
parser.add_argument(
"--x2i_use_naive_impl",
action="store_true",
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index ece47a2bf1..5599a831ec 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -165,6 +165,7 @@ class StartArgs:
x2i_port: int = field(default=None)
http_server_port_for_x2i: int = field(default=None)
x2i_server_used_gpus: int = field(default=1)
+ x2i_server_deploy_mode: str = field(default="colocate")
x2i_use_naive_impl: bool = field(default=False)
x2i_worker_task_port: int = field(default=None)
x2i_worker_nccl_port: int = field(default=None)
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 4a4bf937a3..0304ecea91 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -454,6 +454,9 @@ async def generation_wrapper(prompt, sample, multimodal, request):
# 1. construct 3 or 2 generate based on the multimodel_parmas
sample_params = SamplingParams()
sample_params.init(self.tokenizer, **{"img_gen_prefill": True})
+ sample_params1 = SamplingParams()
+ sample_params1.init(self.tokenizer, **{"img_gen_prefill": True})
+
img_len = len(multimodal_params.images)
image_guidance_scale = generation_params.image_guidance_scale
@@ -466,11 +469,13 @@ async def generation_wrapper(prompt, sample, multimodal, request):
prompt_condition, prompt_text_uncondition, prompt_img_uncondition = self.tokenizer.get_query_for_it2i(
prompt
)
+ sample_params2 = SamplingParams()
+ sample_params2.init(self.tokenizer, **{"img_gen_prefill": True})
(con_gen, text_uncon_gen, img_uncon_gen) = await asyncio.gather(
*[
generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
- generation_wrapper(prompt_text_uncondition, sample_params, multimodal_params.clone(), request),
- generation_wrapper(prompt_img_uncondition, sample_params, MultimodalParams(), request),
+ generation_wrapper(prompt_text_uncondition, sample_params1, multimodal_params.clone(), request),
+ generation_wrapper(prompt_img_uncondition, sample_params2, MultimodalParams(), request),
]
)
generation_params.update_it2i(con_gen, text_uncon_gen, img_uncon_gen)
@@ -481,7 +486,7 @@ async def generation_wrapper(prompt, sample, multimodal, request):
(con_gen, uncon_gen) = await asyncio.gather(
*[
generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
- generation_wrapper(prompt_uncondition, sample_params, multimodal_params.clone(), request),
+ generation_wrapper(prompt_uncondition, sample_params1, multimodal_params.clone(), request),
]
)
generation_params.update_t2i(con_gen, uncon_gen)
diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py
index 76900a92f2..b55d0d6337 100644
--- a/lightllm/server/x2i_server/lightx2v/adapter.py
+++ b/lightllm/server/x2i_server/lightx2v/adapter.py
@@ -1,3 +1,4 @@
+import datetime
import inspect
import torch
import torch.distributed as dist
@@ -38,7 +39,7 @@ def __init__(self, args: StartArgs, rank: int, world_size: int):
torch.cuda.set_device(rank)
self._init_pipeline()
- self.task_dist_group = dist.new_group(backend="gloo")
+ self.task_dist_group = dist.new_group(backend="gloo", timeout=datetime.timedelta(days=30))
def _init_pipeline(self):
os.environ["MASTER_ADDR"] = "127.0.0.1"
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index 28462da8f1..aef9b0ab88 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -218,8 +218,7 @@ def setup_devices(args: StartArgs):
else:
devices = [int(x.strip()) for x in devices.split(",") if x.strip()]
- llm_need_gpus = args.tp * args.dp
- # llm_need_gpus = 0
+ llm_need_gpus = 0 if args.x2i_server_deploy_mode == "colocate" else args.tp * args.dp
x2i_need_gpus = args.x2i_server_used_gpus
if len(devices) < llm_need_gpus + x2i_need_gpus:
raise ValueError(f"devices {devices} not enough, need {llm_need_gpus} and {x2i_need_gpus}")
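setup_devices now reserves tp * dp GPUs for the LLM only in 'separate' mode; in 'colocate' mode the x2i workers reuse the LLM's GPUs, so nothing extra is budgeted. Only the count check is visible in the hunk; the slicing below is an assumption, a sketch of the budgeting rather than the actual setup_devices:

    from typing import List, Tuple

    def budget_devices(devices: List[int], tp: int, dp: int,
                       x2i_gpus: int, deploy_mode: str = "colocate") -> Tuple[List[int], List[int]]:
        llm_need = 0 if deploy_mode == "colocate" else tp * dp
        if len(devices) < llm_need + x2i_gpus:
            raise ValueError(f"devices {devices} not enough, need {llm_need} and {x2i_gpus}")
        # Assumed split: x2i takes the devices after the LLM's slice in 'separate' mode,
        # or the leading devices (shared with the LLM) in 'colocate' mode.
        llm_devices = devices[: tp * dp]
        x2i_devices = devices[llm_need : llm_need + x2i_gpus]
        return llm_devices, x2i_devices

    print(budget_devices([0, 1, 2, 3], tp=2, dp=1, x2i_gpus=1, deploy_mode="separate"))  # ([0, 1], [2])
    print(budget_devices([0, 1, 2, 3], tp=2, dp=1, x2i_gpus=1, deploy_mode="colocate"))  # ([0, 1], [0])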
From 2e7c525a8d9fb454d34a7f6af209bd69b73a5869 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 13 Apr 2026 12:17:49 +0000
Subject: [PATCH 22/41] enable thinking & input_image num
---
lightllm/models/neo_chat_moe/model.py | 6 ++++--
lightllm/server/api_models.py | 2 +-
lightllm/server/api_openai.py | 4 ++--
lightllm/server/httpserver/manager.py | 9 +++++++--
4 files changed, 14 insertions(+), 7 deletions(-)
diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py
index 648d696877..dba8965512 100644
--- a/lightllm/models/neo_chat_moe/model.py
+++ b/lightllm/models/neo_chat_moe/model.py
@@ -154,11 +154,13 @@ def get_query_for_it2i(self, prompt: str):
question_img_uncondition = self._build_t2i_query("")
return query_condition, query_text_uncondition, question_img_uncondition
- def get_query_for_t2i(self, prompt: str):
+ def get_query_for_t2i(self, prompt: str, input_image_num: int = 0):
# the chat template has already been applied to the prompt
image_len = prompt.count(IMG_TOKEN)
query_condition = prompt + IMG_START_TOKEN if not prompt.endswith(IMG_START_TOKEN) else prompt
- query_uncondition = self._build_t2i_query("", thinking_content=IMG_TOKEN * image_len)
+ query_uncondition = self._build_t2i_query(
+ IMG_TOKEN * input_image_num, thinking_content=IMG_TOKEN * (image_len - input_image_num)
+ )
return query_condition, query_uncondition
diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py
index 664a248e80..c2fce7dd95 100644
--- a/lightllm/server/api_models.py
+++ b/lightllm/server/api_models.py
@@ -418,7 +418,7 @@ class ImageConfig(BaseModel):
"4:5": (896, 1152),
"5:4": (1152, 896),
"9:16": (768, 1344),
- "16:9": (1344, 768),
+ "16:9": (1920, 1080),
"21:9": (1536, 672),
}
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index 52fe59aa39..dbbfbb5528 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -668,7 +668,6 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
if request.chat_template_kwargs is None:
request.chat_template_kwargs = {}
- request.chat_template_kwargs.update({"enable_thinking": False})
chat_request: ChatCompletionRequest = ChatCompletionRequest(**request.model_dump())
@@ -689,6 +688,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
created_time = int(time.time())
prompt, sampling_params, multimodal_params = await _get_text_generator_input(chat_request)
+ input_image_num = len(multimodal_params.images)
x2i_params = X2IParams()
x2i_params.init_from_image_config(request.image_config)
@@ -732,7 +732,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
if need_call_x2i:
prompt += output_chunk
images = await g_objs.httpserver_manager.generate_image(
- prompt, x2i_params, multimodal_params.clone(), request=raw_request
+ prompt, x2i_params, multimodal_params.clone(), request=raw_request, input_image_num=input_image_num
)
if len(images) == 0:
logger.warning(f"No image generated by x2i: {prompt[-100:]}, exit...")
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 0304ecea91..31d445483b 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -435,7 +435,12 @@ async def generate(
return
async def generate_image(
- self, prompt: str, generation_params: X2IParams, multimodal_params: MultimodalParams, request: Request
+ self,
+ prompt: str,
+ generation_params: X2IParams,
+ multimodal_params: MultimodalParams,
+ request: Request,
+ input_image_num: int = 0,
):
generate_req_ids = []
@@ -481,7 +486,7 @@ async def generation_wrapper(prompt, sample, multimodal, request):
generation_params.update_it2i(con_gen, text_uncon_gen, img_uncon_gen)
else:
# call t2i
- prompt_condition, prompt_uncondition = self.tokenizer.get_query_for_t2i(prompt)
+ prompt_condition, prompt_uncondition = self.tokenizer.get_query_for_t2i(prompt, input_image_num)
# logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}")
(con_gen, uncon_gen) = await asyncio.gather(
*[
From c9d19d011e883aff061dc3f49d22b2c6908083d5 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 13 Apr 2026 13:14:52 +0000
Subject: [PATCH 23/41] keep the same resolution for it2i
---
lightllm/server/api_openai.py | 3 ++-
lightllm/server/core/objs/x2i_params.py | 12 ++++++++++++
lightllm/server/httpserver/manager.py | 5 +++++
3 files changed, 19 insertions(+), 1 deletion(-)
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index dbbfbb5528..55112d7a4b 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -630,8 +630,9 @@ async def _chat_completion_image_only(
from .api_http import g_objs
created_time = int(time.time())
+ input_image_num = len(multimodal_params.images)
images = await g_objs.httpserver_manager.generate_image(
- prompt, x2i_params, multimodal_params.clone(), request=raw_request
+ prompt, x2i_params, multimodal_params.clone(), request=raw_request, input_image_num=input_image_num
)
response_images = _message_contents_from_raw_images(images, request.image_config.image_type)
chat_message = ChatMessage(
diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py
index e0f1bdb30c..281e5f66d2 100644
--- a/lightllm/server/core/objs/x2i_params.py
+++ b/lightllm/server/core/objs/x2i_params.py
@@ -78,6 +78,7 @@ def _get(key, default):
self.past_kvcache_img = PastKVCachePageList()
self.total_prompt_tokens = 0
self.request_id = 0
+ self.has_updated_hw = False
def init_from_image_config(self, image_config: Any) -> None:
"""从 HTTP `image_config`(api_models.ImageConfig)填充,与 `init(**kwargs)` 共用默认值逻辑。"""
@@ -104,6 +105,17 @@ def init_from_image_config(self, image_config: Any) -> None:
break
self.init(**kwargs)
+ def update_hw(self, width: int, height: int):
+ if self.has_updated_hw:
+ return
+ from lightllm.models.neo_chat_moe.vision_process import smart_resize
+
+ h, w = smart_resize(height, width, factor=32, min_pixels=512 * 512, max_pixels=2048 * 2048)
+ self.width = w
+ self.height = h
+ self.has_updated_hw = True
+ return
+
def update(self, past_kv: PastKVCachePageList, meta: Dict):
item: PastKVCacheItem = meta.get("kv_cache_item")
past_kv.token_len = item.token_len
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 31d445483b..99ed011cb1 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -495,6 +495,11 @@ async def generation_wrapper(prompt, sample, multimodal, request):
]
)
generation_params.update_t2i(con_gen, uncon_gen)
+ if input_image_num > 0:
+ # for it2i, the output image size is the same as the input image size
+ generation_params.update_hw(
+ multimodal_params.images[0].image_w, multimodal_params.images[0].image_h
+ )
# use the first request id as the gen image request id
x2i_req_id = generate_req_ids[0]
generation_params.request_id = x2i_req_id
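update_hw delegates to smart_resize from neo_chat_moe.vision_process. A rough sketch of the expected rounding, assuming it follows the usual Qwen-VL-style smart_resize (both sides snapped to multiples of factor, area clamped to [min_pixels, max_pixels], aspect ratio preserved); the real implementation may differ in details:

    import math

    def smart_resize(height: int, width: int, factor: int = 32,
                     min_pixels: int = 512 * 512, max_pixels: int = 2048 * 2048):
        # Round both sides to multiples of `factor`, then rescale if the area
        # falls outside [min_pixels, max_pixels] while keeping the aspect ratio.
        h = max(factor, round(height / factor) * factor)
        w = max(factor, round(width / factor) * factor)
        if h * w > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            h = math.floor(height / beta / factor) * factor
            w = math.floor(width / beta / factor) * factor
        elif h * w < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            h = math.ceil(height * beta / factor) * factor
            w = math.ceil(width * beta / factor) * factor
        return h, w

    # e.g. a 1000x700 input image snaps to a 32-aligned size of similar area
    print(smart_resize(700, 1000))  # -> (704, 992)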
From d0967b70dbc6ef9ba9a5adb6fe1656ad3e52473c Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Tue, 14 Apr 2026 10:58:57 +0000
Subject: [PATCH 24/41] fixup.
---
.../triton_kernel/kv_cache_offload.py | 2 +-
.../server/router/model_infer/infer_batch.py | 4 ++
.../model_infer/mode_backend/past_kv_cache.py | 39 +++++++++++++++----
3 files changed, 36 insertions(+), 9 deletions(-)
diff --git a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py
index 686c0c1790..01d32e908c 100644
--- a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py
+++ b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py
@@ -916,7 +916,7 @@ def offload_gpu_kv_to_cpu_for_x2i(
assert gpu_heads == 1
assert cpu_heads == 1
- need_offload == tp_index == 0
+ need_offload = tp_index == 0
cpu_k_start_head_index = 0
cpu_k_head_num = 1
gpu_k_start_head_index = 0
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 0088c60d65..898a48f87a 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -590,6 +590,10 @@ def handle(
# Information needed by the detokenization process; note the write order of these variables to avoid async coordination issues.
shm_req.shm_cur_output_len = self.output_len
+ if shm_req.sample_params.img_gen_prefill:
+ # img gen prefill must wait until the kv cache has been offloaded to cpu before updating the info needed by detokenization
+ return
+
if finish_status.is_finished():
shm_req.finish_token_index = shm_req.input_len + self.output_len - 1
shm_req.finish_status = req_obj.finish_status
diff --git a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
index 9035de219b..9d0a8d4dd4 100644
--- a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
+++ b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
@@ -20,17 +20,19 @@
class TransTask:
req_obj: InferReq
sync_event: torch.cuda.Event
-
+ buffer_index: int
class PastKVCacheModule(object):
def __init__(self, backend):
from .base_backend import ModeBackend
self.backend: ModeBackend = backend
self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=False)
- self.page_index_buffer = torch.empty((LIGHTLLM_TOKEN_HASH_LIST_SIZE * 2,), dtype=torch.int32, device="cuda")
+ self.page_index_buffer = torch.empty((1024 * LIGHTLLM_TOKEN_HASH_LIST_SIZE,), dtype=torch.int32, device="cuda")
+ self.page_index_buffer_free_index = list(range(1024))
self.past_kv_cache_task: Deque[TransTask] = deque()
self.sync_task_status_group = create_new_group_for_current_dp("gloo")
+
@lru_cache()
def need_sync_compute_stream(self) -> bool:
"""
@@ -51,7 +53,6 @@ def need_sync_compute_stream(self) -> bool:
logger.info("PastKVCacheModule: no need sync compute stream.")
return False
-
def offload_finished_reqs_to_past_kv_cache(self, finished_reqs: List[InferReq]) -> List[InferReq]:
"""
Offload the finished reqs to past kv cache, and return the truly finished reqs that can be freed in infer batch.
@@ -70,28 +71,37 @@ def offload_finished_reqs_to_past_kv_cache(self, finished_reqs: List[InferReq])
if req.past_kv_cache_task_status.is_running():
continue
- assert req.past_kv_cache_task_status.is_not_started()
+ assert req.past_kv_cache_task_status.is_not_started(), \
+ f"req {req.req_id} has invalid past kv cache task status {req.past_kv_cache_task_status}"
if self.need_sync_compute_stream():
g_infer_context.get_overlap_stream().synchronize()
trans_task = self._start_kv_cache_offload(req=req)
- assert trans_task is not None
+ assert trans_task is not None, f"req {req.req_id} start kv cache offload failed"
self.past_kv_cache_task.append(trans_task)
-
return true_finished_reqs
def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]:
with torch.cuda.stream(g_infer_context.get_cpu_kv_cache_stream()):
+ if len(self.page_index_buffer_free_index) == 0:
+ raise RuntimeError("No free page index for offloading past kv cache to CPU.")
+
+ assert req.shm_req.past_kv_cache_page_indexes.size <= LIGHTLLM_TOKEN_HASH_LIST_SIZE
+
+ free_index = self.page_index_buffer_free_index.pop(0)
+ start = free_index * LIGHTLLM_TOKEN_HASH_LIST_SIZE
+ end = start + req.shm_req.past_kv_cache_page_indexes.size
+
page_indexes = torch.tensor(req.shm_req.past_kv_cache_page_indexes.get_all(), dtype=torch.int32, device='cpu', pin_memory=True)
num_tokens = req.shm_req.input_len
assert req.cur_kv_len >= num_tokens
assert num_tokens <= len(page_indexes) * self.past_kv_cache_client.token_page_size
- cuda_page_indexes = self.page_index_buffer[:len(page_indexes)]
+ cuda_page_indexes = self.page_index_buffer[start:end]
cuda_page_indexes.copy_(page_indexes)
token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0: num_tokens]
@@ -102,7 +112,7 @@ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]:
cpu_cache_meta = self.past_kv_cache_client.kv_cache_tensor_meta
cpu_kv_cache = self.past_kv_cache_client.cpu_kv_cache_tensor[:, :, :, :, 0:cpu_cache_meta.head_dim]
cpu_kv_cache_scale = self.past_kv_cache_client.cpu_kv_cache_tensor[
- :, :, :, :, cpu_cache_meta.head_dim
+ :, :, :, :, cpu_cache_meta.head_dim :
].view(mem_manager.scale_buffer.dtype)
gpu_kv_cache_scale = mem_manager.scale_buffer
else:
@@ -124,10 +134,12 @@ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]:
)
sync_event = torch.cuda.Event()
sync_event.record()
+ # sync_event.synchronize()
req.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING
return TransTask(
req_obj=req,
sync_event=sync_event,
+ buffer_index=free_index
)
def update_past_kv_cache_task_states(self):
@@ -148,3 +160,14 @@ def update_past_kv_cache_task_states(self):
self.past_kv_cache_task.extendleft(reversed(unfinished))
for task in finished:
task.req_obj.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED
+ self.page_index_buffer_free_index.append(task.buffer_index)
+
+ if self.backend.is_master_in_dp:
+ shm_req = task.req_obj.shm_req
+ assert task.req_obj.finish_status.is_finished()
+ shm_req.finish_token_index = shm_req.input_len + shm_req.shm_cur_output_len - 1
+ shm_req.finish_status = task.req_obj.finish_status
+ shm_req.candetoken_out_len = shm_req.shm_cur_output_len
+ else:
+ if len(trans_ok_tasks) > 0:
+ self.past_kv_cache_task.extendleft(reversed(trans_ok_tasks))
\ No newline at end of file
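The switch from a single shared page_index_buffer to 1024 fixed-size slots plus a free-index list is, in effect, a small slot allocator over one preallocated CUDA tensor. A condensed sketch of that lifecycle (SLOT_TOKENS stands in for LIGHTLLM_TOKEN_HASH_LIST_SIZE; error handling elided, a CUDA device assumed):

    import torch

    SLOT_TOKENS = 4096   # stand-in value for LIGHTLLM_TOKEN_HASH_LIST_SIZE
    NUM_SLOTS = 1024

    buffer = torch.empty(NUM_SLOTS * SLOT_TOKENS, dtype=torch.int32, device="cuda")
    free_slots = list(range(NUM_SLOTS))

    def acquire(page_indexes: torch.Tensor) -> int:
        # Pop a slot and copy the (pinned) CPU page indexes into its sub-range;
        # the slot id is kept on the TransTask and released once the offload finishes.
        slot = free_slots.pop(0)
        start = slot * SLOT_TOKENS
        buffer[start:start + page_indexes.numel()].copy_(page_indexes, non_blocking=True)
        return slot

    def release(slot: int) -> None:
        free_slots.append(slot)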
From f6ae0d1974a7194e607adc8ec6eac9291f8d1460 Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Wed, 15 Apr 2026 06:34:36 +0000
Subject: [PATCH 25/41] workaround for illegal memory access.
---
.../server/router/model_infer/mode_backend/past_kv_cache.py | 1 +
lightllm/server/x2i_server/lightx2v/adapter.py | 2 ++
2 files changed, 3 insertions(+)
diff --git a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
index 9d0a8d4dd4..e6457ded9a 100644
--- a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
+++ b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
@@ -134,6 +134,7 @@ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]:
)
sync_event = torch.cuda.Event()
sync_event.record()
+ sync_event.wait(g_infer_context.get_overlap_stream())
# sync_event.synchronize()
req.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING
return TransTask(
diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py
index b55d0d6337..b738b00d7d 100644
--- a/lightllm/server/x2i_server/lightx2v/adapter.py
+++ b/lightllm/server/x2i_server/lightx2v/adapter.py
@@ -109,6 +109,8 @@ async def _process(self, param: X2IParams):
# release
await self.result_socket.send_pyobj(X2ICacheRelease(request_id=param.request_id))
+ logger.info(f"{'t2i' if is_t2i else 'it2i'} generate images with: {param}")
+
if is_t2i:
self.pipe.runner.set_kvcache(
past_kv_cache,
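The added sync_event.wait(...) in this patch orders the offload copy against the overlap stream without blocking the host. In isolation the pattern looks like this (streams created locally here; the server reuses the streams held by g_infer_context):

    import torch

    copy_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()

    with torch.cuda.stream(copy_stream):
        # ... enqueue the GPU -> CPU KV copy on copy_stream ...
        done = torch.cuda.Event()
        done.record()

    # Work later enqueued on compute_stream must not touch the source KV slots
    # before the copy finishes; waiting on the event orders the two streams
    # without a host-side synchronize().
    done.wait(compute_stream)
    with torch.cuda.stream(compute_stream):
        pass  # subsequent kernels, e.g. the next prefill/decode step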
From 20439fefdeb56159816e96fe30741266253b6e1f Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Mon, 20 Apr 2026 08:05:57 +0000
Subject: [PATCH 26/41] fix attention
---
.../context_attention_fwd_neo.py | 2 +-
.../test_context_attention_fwd_neo.py | 345 ++++++++++++++++++
2 files changed, 346 insertions(+), 1 deletion(-)
create mode 100644 unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
index 74ff82cae4..a745b7c9da 100644
--- a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
+++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
@@ -103,7 +103,7 @@ def _fwd_kernel(
qk += tl.dot(q, k)
# mask: causal OR same gid (only possible inside NEW part)
- mask = (q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None]
+ mask = ((q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None]) & k_valid[None, :]
qk = tl.where(mask, qk * sm_scale, -1.0e8)
# online softmax
diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py
new file mode 100644
index 0000000000..4abfde5a83
--- /dev/null
+++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py
@@ -0,0 +1,345 @@
+"""Unit test for ``context_attention_fwd_neo``.
+
+Torch reference expresses the *semantics* of the attention, not the kernel's
+internal block structure — it has no notion of BLOCK_N / BLOCK_M. For each
+batch element we gather K/V for the whole request (prompt + new tokens) via
+``req_to_token_indexs`` and apply::
+
+ allow[m, k] = (k <= q_pos[m]) OR image_tag[m] for k in [0, total)
+
+i.e. normal queries are causal, image-token queries can see every real key in
+the request. If the Triton kernel disagrees with this reference, the kernel is
+wrong.
+
+Run directly for quick debugging:
+
+ python unit_tests/common/basemodel/triton_kernel/att/prefill_att/\
+ test_context_attention_fwd_neo.py
+
+or via pytest:
+
+ pytest unit_tests/common/basemodel/triton_kernel/att/prefill_att/\
+ test_context_attention_fwd_neo.py -x -s
+"""
+
+import math
+import pytest
+import torch
+
+from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import (
+ context_attention_fwd_neo,
+)
+
+
+def torch_reference_context_attention_neo(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ b_req_idx: torch.Tensor,
+ b_start_loc: torch.Tensor,
+ b_seq_len: torch.Tensor,
+ b_prompt_cache_len: torch.Tensor,
+ req_to_token_indexs: torch.Tensor,
+ b_image_token_tag: torch.Tensor,
+) -> torch.Tensor:
+ device = q.device
+ dtype = q.dtype
+ _, Hq, D = q.shape
+ Hk = k.shape[1]
+ kv_group = Hq // Hk
+ scale = 1.0 / math.sqrt(D)
+
+ out = torch.empty_like(q)
+
+ for b in range(b_seq_len.shape[0]):
+ req_idx = int(b_req_idx[b].item())
+ total = int(b_seq_len[b].item())
+ prompt = int(b_prompt_cache_len[b].item())
+ new = total - prompt
+ if new <= 0:
+ continue
+
+ q_start = int(b_start_loc[b].item())
+ q_blk = q[q_start : q_start + new] # [M, Hq, D]
+ image_tag = b_image_token_tag[q_start : q_start + new].to(torch.bool)
+
+ token_locs = req_to_token_indexs[req_idx, :total].to(torch.int64)
+ k_blk = k[token_locs] # [total, Hk, D]
+ v_blk = v[token_locs]
+
+ q_pos = torch.arange(prompt, total, device=device, dtype=torch.int64) # [M]
+ k_pos = torch.arange(0, total, device=device, dtype=torch.int64) # [total]
+ causal = k_pos[None, :] <= q_pos[:, None]
+ allow = causal | image_tag[:, None]
+
+ out_blk = torch.empty_like(q_blk)
+ for h in range(Hq):
+ h_k = h // kv_group
+ q_h = q_blk[:, h, :].to(torch.float32)
+ k_h = k_blk[:, h_k, :].to(torch.float32)
+ v_h = v_blk[:, h_k, :].to(torch.float32)
+
+ scores = (q_h @ k_h.transpose(0, 1)) * scale
+ scores = torch.where(allow, scores, torch.full_like(scores, -1.0e8))
+ probs = torch.softmax(scores, dim=-1)
+ out_h = (probs @ v_h).to(dtype)
+ out_blk[:, h, :] = out_h
+
+ out[q_start : q_start + new] = out_blk
+
+ return out
+
+
+def _build_inputs(
+ batch: int,
+ Hq: int,
+ Hk: int,
+ D: int,
+ dtype: torch.dtype,
+ device: str,
+ max_new: int = 256,
+ max_prompt: int = 512,
+ image_prob: float = 0.7,
+ num_image_spans_max: int = 3,
+ image_span_len_max: int = 24,
+ kv_pool_slack: int = 4096,
+ seed: int = 0,
+):
+ g = torch.Generator(device="cpu").manual_seed(seed)
+
+ new_lens = torch.randint(low=1, high=max_new + 1, size=(batch,), generator=g)
+ prompt_lens = torch.randint(low=0, high=max_prompt + 1, size=(batch,), generator=g)
+ total_lens = new_lens + prompt_lens
+
+ sum_new = int(new_lens.sum().item())
+ sum_total = int(total_lens.sum().item())
+ max_total_len = int(total_lens.max().item())
+ max_new_len = int(new_lens.max().item())
+
+ b_start_loc = torch.zeros(batch, dtype=torch.int32)
+ cur = 0
+ for i in range(batch):
+ b_start_loc[i] = cur
+ cur += int(new_lens[i].item())
+
+ # Permute so batch idx != request idx: exercises the Req_to_tokens indexing.
+ b_req_idx = torch.randperm(batch, generator=g).to(torch.int32)
+
+ # Global KV pool with scattered, non-contiguous slot assignment per request.
+ base = 1024
+ kv_pool_size = base + sum_total + kv_pool_slack
+ pool = torch.randperm(kv_pool_size - base, generator=g)[:sum_total] + base
+
+ req_to_token_indexs = torch.zeros((batch, max_total_len), dtype=torch.int32)
+ p = 0
+ for r_logical, req_id in enumerate(b_req_idx.tolist()):
+ L = int(total_lens[r_logical].item())
+ req_to_token_indexs[req_id, :L] = pool[p : p + L].to(torch.int32)
+ p += L
+
+ # Randomly place contiguous image-token spans inside each batch's NEW region.
+ b_image_token_tag = torch.zeros(sum_new, dtype=torch.bool)
+ for i in range(batch):
+ M = int(new_lens[i].item())
+ if M < 2:
+ continue
+ if torch.rand((), generator=g).item() > image_prob:
+ continue
+ n_spans = int(torch.randint(1, num_image_spans_max + 1, (1,), generator=g).item())
+ start_pack = int(b_start_loc[i].item())
+ for _ in range(n_spans):
+ span_len = int(
+ torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item()
+ )
+ span_len = min(span_len, M)
+ s_rel = int(torch.randint(0, M - span_len + 1, (1,), generator=g).item())
+ b_image_token_tag[start_pack + s_rel : start_pack + s_rel + span_len] = True
+
+ b_seq_len = total_lens.to(torch.int32)
+ b_prompt_cache_len = prompt_lens.to(torch.int32)
+
+ # position_ids[0]: kernel API still requires it even though its current
+ # mask logic only reads b_image_token_tag.
+ position_ids_0 = torch.empty(sum_new, dtype=torch.int32)
+ for i in range(batch):
+ M = int(new_lens[i].item())
+ P = int(prompt_lens[i].item())
+ s = int(b_start_loc[i].item())
+ position_ids_0[s : s + M] = torch.arange(P, P + M, dtype=torch.int32)
+
+ q = torch.randn((sum_new, Hq, D), dtype=dtype, device=device)
+ k = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device)
+ v = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device)
+ o = torch.empty_like(q)
+
+ return dict(
+ q=q,
+ k=k,
+ v=v,
+ o=o,
+ position_ids_0=position_ids_0.to(device),
+ b_req_idx=b_req_idx.to(device),
+ b_start_loc=b_start_loc.to(device),
+ b_seq_len=b_seq_len.to(device),
+ b_prompt_cache_len=b_prompt_cache_len.to(device),
+ max_new_len=max_new_len,
+ req_to_token_indexs=req_to_token_indexs.to(device),
+ b_image_token_tag=b_image_token_tag.to(device),
+ new_lens=new_lens,
+ prompt_lens=prompt_lens,
+ )
+
+
+def _report_per_batch_error(out_triton, out_ref, new_lens, b_start_loc, image_tag, tag=""):
+ print(f"\n[{tag}] per-batch error breakdown (abs / rel / cos):")
+ for i in range(new_lens.shape[0]):
+ s = int(b_start_loc[i].item())
+ m = int(new_lens[i].item())
+ if m == 0:
+ continue
+ a = out_triton[s : s + m].float()
+ b = out_ref[s : s + m].float()
+ abs_err = (a - b).abs().max().item()
+ denom = b.abs().max().item() + 1e-6
+ rel_err = abs_err / denom
+ cos = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
+ n_img = int(image_tag[s : s + m].sum().item())
+ print(
+ f" batch {i:02d} | M={m:4d} | image_tokens={n_img:4d} | "
+ f"max_abs={abs_err:.4e} | max_rel={rel_err:.4e} | cos={cos:.6f}"
+ )
+
+
+def _run_case(
+ batch: int,
+ Hq: int,
+ Hk: int,
+ D: int,
+ dtype: torch.dtype,
+ seed: int,
+ max_new: int,
+ max_prompt: int,
+ atol: float = 5e-2,
+ rtol: float = 5e-2,
+ cos_threshold: float = 0.99,
+ verbose: bool = True,
+):
+ assert Hq % Hk == 0
+ device = "cuda"
+
+ inputs = _build_inputs(
+ batch=batch,
+ Hq=Hq,
+ Hk=Hk,
+ D=D,
+ dtype=dtype,
+ device=device,
+ max_new=max_new,
+ max_prompt=max_prompt,
+ seed=seed,
+ )
+
+ context_attention_fwd_neo(
+ inputs["q"],
+ inputs["k"],
+ inputs["v"],
+ inputs["o"],
+ inputs["position_ids_0"],
+ inputs["b_req_idx"],
+ inputs["b_start_loc"],
+ inputs["b_seq_len"],
+ inputs["b_prompt_cache_len"],
+ inputs["max_new_len"],
+ inputs["req_to_token_indexs"],
+ inputs["b_image_token_tag"],
+ )
+ out_triton = inputs["o"]
+
+ out_ref = torch_reference_context_attention_neo(
+ inputs["q"],
+ inputs["k"],
+ inputs["v"],
+ inputs["b_req_idx"],
+ inputs["b_start_loc"],
+ inputs["b_seq_len"],
+ inputs["b_prompt_cache_len"],
+ inputs["req_to_token_indexs"],
+ inputs["b_image_token_tag"],
+ )
+
+ a = out_triton.float()
+ b = out_ref.float()
+ abs_err = (a - b).abs().max().item()
+ denom = b.abs().max().item() + 1e-6
+ rel_err = abs_err / denom
+ cos = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
+
+ n_image = int(inputs["b_image_token_tag"].sum().item())
+ n_tokens = int(inputs["b_image_token_tag"].numel())
+ if verbose:
+ print(
+ f"\ncase: batch={batch} Hq={Hq} Hk={Hk} D={D} dtype={dtype} "
+ f"seed={seed} image_tokens={n_image}/{n_tokens}"
+ )
+ print(
+ f" global: max_abs={abs_err:.4e} max_rel={rel_err:.4e} cos={cos:.6f} "
+ f"(allclose atol={atol}, rtol={rtol}? "
+ f"{torch.allclose(a, b, atol=atol, rtol=rtol)})"
+ )
+ _report_per_batch_error(
+ out_triton,
+ out_ref,
+ inputs["new_lens"],
+ inputs["b_start_loc"],
+ inputs["b_image_token_tag"],
+ tag=f"seed={seed}",
+ )
+
+ return abs_err, rel_err, cos
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA")
+@pytest.mark.parametrize(
+ "batch,Hq,Hk,D,dtype,seed,max_new,max_prompt",
+ [
+ (4, 8, 2, 128, torch.bfloat16, 0, 128, 256),
+ (4, 8, 2, 128, torch.bfloat16, 1, 256, 512),
+ (8, 16, 4, 128, torch.bfloat16, 2, 256, 512),
+ (16, 28, 4, 128, torch.bfloat16, 3, 128, 256),
+ (4, 8, 2, 128, torch.float16, 4, 256, 512),
+ (4, 8, 8, 64, torch.bfloat16, 5, 128, 256),
+ (3, 8, 2, 128, torch.bfloat16, 6, 8, 1024),
+ ],
+)
+def test_context_attention_fwd_neo(batch, Hq, Hk, D, dtype, seed, max_new, max_prompt):
+ abs_err, rel_err, cos = _run_case(
+ batch=batch,
+ Hq=Hq,
+ Hk=Hk,
+ D=D,
+ dtype=dtype,
+ seed=seed,
+ max_new=max_new,
+ max_prompt=max_prompt,
+ verbose=True,
+ )
+ assert cos > 0.99, f"cosine similarity too low: {cos}"
+ assert rel_err < 5e-2, f"max relative error too large: {rel_err}"
+
+
+if __name__ == "__main__":
+ if not torch.cuda.is_available():
+ print("No CUDA available.")
+ raise SystemExit(0)
+
+ torch.manual_seed(0)
+
+ cases = [
+ dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.bfloat16, seed=0, max_new=128, max_prompt=256),
+ dict(batch=8, Hq=16, Hk=4, D=128, dtype=torch.bfloat16, seed=1, max_new=256, max_prompt=512),
+ dict(batch=16, Hq=28, Hk=4, D=128, dtype=torch.bfloat16, seed=2, max_new=128, max_prompt=256),
+ dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.float16, seed=3, max_new=256, max_prompt=512),
+ ]
+
+ for cfg in cases:
+ _run_case(**cfg, verbose=True)
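As a quick sanity check of the corrected mask, the reference rule used by the test can be evaluated by hand for a tiny request: 5 total tokens, 2 of them cached prefix, with the middle new query token tagged as an image token (the extra k_valid term only matters for padded keys, of which this toy case has none):

    import torch

    total, prompt = 5, 2                        # 2 cached prefix keys, 3 new queries
    q_pos = torch.arange(prompt, total)         # positions of the new queries: [2, 3, 4]
    k_pos = torch.arange(total)                 # key positions: [0, 1, 2, 3, 4]
    image_tag = torch.tensor([False, True, False])

    causal = k_pos[None, :] <= q_pos[:, None]
    allow = causal | image_tag[:, None]         # image-token rows see every real key
    print(allow.int())
    # tensor([[1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 1],
    #         [1, 1, 1, 1, 1]])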
From 35906f5b2422e0bc3f3600c9fdd6788020c25330 Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Mon, 20 Apr 2026 14:02:58 +0000
Subject: [PATCH 27/41] pass unit test
---
.../att/prefill_att/test_fa3_neo.py | 548 ++++++++++++++++++
1 file changed, 548 insertions(+)
create mode 100644 unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py
diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py
new file mode 100644
index 0000000000..79b1e8e343
--- /dev/null
+++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py
@@ -0,0 +1,548 @@
+"""Unit test for the FA3-based prefill path with image-token support.
+
+This test pre-wires a call to ``flash_attn_with_kvcache`` with an
+``image_token_tag`` keyword argument. The expectation is that ``fa3-neo``'s
+``flash_attn_with_kvcache`` will be extended with an optional
+``image_token_tag`` parameter that, for queries flagged as image tokens,
+relaxes the causal mask so they can attend bidirectionally to every real key
+in the request.
+
+Torch reference expresses the *semantics* of the attention, not FA3's internal
+tiling — it has no notion of BLOCK_N / BLOCK_M. For each batch element we
+gather K/V for the whole request (prompt + new tokens) and apply::
+
+ allow[m, k] = (k <= q_pos[m]) OR image_tag[m] for k in [0, total)
+
+i.e. normal queries are causal, image-token queries can see every real key in
+the request. If FA3 disagrees with this reference, the kernel is wrong.
+
+Run directly for quick debugging:
+
+ python unit_tests/common/basemodel/triton_kernel/att/prefill_att/\
+ test_fa3_neo.py
+
+or via pytest:
+
+ pytest unit_tests/common/basemodel/triton_kernel/att/prefill_att/\
+ test_fa3_neo.py -x -s
+"""
+
+import math
+import pytest
+import torch
+
+try:
+ from flash_attn_interface import flash_attn_with_kvcache
+except ImportError:
+ flash_attn_with_kvcache = None
+
+try:
+ import triton
+ import triton.testing as triton_testing
+except ImportError:
+ triton = None
+ triton_testing = None
+
+try:
+ from lightllm.models.neo_chat_moe.triton_kernel.context_attention_fwd_neo import (
+ context_attention_fwd_neo,
+ )
+except ImportError:
+ context_attention_fwd_neo = None
+
+
+def torch_reference_context_attention_neo(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ b_req_idx: torch.Tensor,
+ b_q_start_loc: torch.Tensor,
+ b_seq_len: torch.Tensor,
+ b_prompt_cache_len: torch.Tensor,
+ req_to_token_indexs: torch.Tensor,
+ b_image_token_tag: torch.Tensor,
+) -> torch.Tensor:
+ device = q.device
+ dtype = q.dtype
+ _, Hq, D = q.shape
+ Hk = k.shape[1]
+ kv_group = Hq // Hk
+ scale = 1.0 / math.sqrt(D)
+
+ out = torch.empty_like(q)
+
+ for b in range(b_seq_len.shape[0]):
+ req_idx = int(b_req_idx[b].item())
+ seq_len = int(b_seq_len[b].item())
+ prompt_cache_len = int(b_prompt_cache_len[b].item())
+ q_seq_len = seq_len - prompt_cache_len
+ if q_seq_len <= 0:
+ continue
+
+ q_start = int(b_q_start_loc[b].item())
+ q_blk = q[q_start : q_start + q_seq_len] # [M, Hq, D]
+ image_tag = b_image_token_tag[q_start : q_start + q_seq_len].to(torch.bool)
+
+ token_locs = req_to_token_indexs[req_idx, :seq_len].to(torch.int64)
+ k_blk = k[token_locs] # [seq_len, Hk, D]
+ v_blk = v[token_locs]
+
+ q_pos = torch.arange(prompt_cache_len, seq_len, device=device, dtype=torch.int64) # [M]
+ k_pos = torch.arange(0, seq_len, device=device, dtype=torch.int64) # [seq_len]
+ causal = k_pos[None, :] <= q_pos[:, None]
+ allow = causal | image_tag[:, None]
+
+ out_blk = torch.empty_like(q_blk)
+ for h in range(Hq):
+ h_k = h // kv_group
+ q_h = q_blk[:, h, :].to(torch.float32)
+ k_h = k_blk[:, h_k, :].to(torch.float32)
+ v_h = v_blk[:, h_k, :].to(torch.float32)
+
+ scores = (q_h @ k_h.transpose(0, 1)) * scale
+ scores = torch.where(allow, scores, torch.full_like(scores, -1.0e8))
+ probs = torch.softmax(scores, dim=-1)
+ out_h = (probs @ v_h).to(dtype)
+ out_blk[:, h, :] = out_h
+
+ out[q_start : q_start + q_seq_len] = out_blk
+
+ return out
+
+
+def _build_inputs(
+ batch: int,
+ Hq: int,
+ Hk: int,
+ D: int,
+ dtype: torch.dtype,
+ device: str,
+ max_q_seq_len: int = 256,
+ max_prompt_cache_len: int = 512,
+ image_prob: float = 0.7,
+ num_image_spans_max: int = 3,
+ image_span_len_max: int = 24,
+ kv_pool_slack: int = 4096,
+ seed: int = 0,
+):
+ """Build one realistic prefill batch.
+
+ Naming matches lightllm's infer_state:
+ - ``q_seq_len`` = number of new Q tokens in this prefill call
+ - ``prompt_cache_len`` = length of the already-cached prefix for this req
+ - ``seq_len`` = prompt_cache_len + q_seq_len (total KV length)
+ """
+ g = torch.Generator(device="cpu").manual_seed(seed)
+
+ q_seq_lens = torch.randint(low=1, high=max_q_seq_len + 1, size=(batch,), generator=g)
+ prompt_cache_lens = torch.randint(low=0, high=max_prompt_cache_len + 1, size=(batch,), generator=g)
+ seq_lens = q_seq_lens + prompt_cache_lens
+
+ sum_q = int(q_seq_lens.sum().item())
+ sum_total = int(seq_lens.sum().item())
+ max_seq_len_in_batch = int(seq_lens.max().item())
+ max_q_seq_len_in_batch = int(q_seq_lens.max().item())
+
+ b_q_start_loc = torch.zeros(batch, dtype=torch.int32)
+ cur = 0
+ for i in range(batch):
+ b_q_start_loc[i] = cur
+ cur += int(q_seq_lens[i].item())
+
+ # Permute so batch idx != request idx: exercises the page_table indexing.
+ b_req_idx = torch.randperm(batch, generator=g).to(torch.int32)
+
+ # Global KV pool with scattered, non-contiguous slot assignment per request.
+ base = 1024
+ kv_pool_size = base + sum_total + kv_pool_slack
+ pool = torch.randperm(kv_pool_size - base, generator=g)[:sum_total] + base
+
+ req_to_token_indexs = torch.zeros((batch, max_seq_len_in_batch), dtype=torch.int32)
+ p = 0
+ for r_logical, req_id in enumerate(b_req_idx.tolist()):
+ L = int(seq_lens[r_logical].item())
+ req_to_token_indexs[req_id, :L] = pool[p : p + L].to(torch.int32)
+ p += L
+
+ # Randomly place contiguous image-token spans inside each batch's new-Q region.
+ b_image_token_tag = torch.zeros(sum_q, dtype=torch.bool)
+ for i in range(batch):
+ M = int(q_seq_lens[i].item())
+ if M < 2:
+ continue
+ if torch.rand((), generator=g).item() > image_prob:
+ continue
+ n_spans = int(torch.randint(1, num_image_spans_max + 1, (1,), generator=g).item())
+ start_pack = int(b_q_start_loc[i].item())
+ for _ in range(n_spans):
+ span_len = int(
+ torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item()
+ )
+ span_len = min(span_len, M)
+ s_rel = int(torch.randint(0, M - span_len + 1, (1,), generator=g).item())
+ b_image_token_tag[start_pack + s_rel : start_pack + s_rel + span_len] = True
+
+ b_seq_len = seq_lens.to(torch.int32)
+ b_prompt_cache_len = prompt_cache_lens.to(torch.int32)
+
+ # Per-batch last image-token index in *batch-local* packed-q coordinates.
+ # Matches NeoChatInferStateInfo._compute_b_max_image_q_idx semantics:
+ # shape int32[batch]; value == -1 means that batch has no image tokens.
+ # Computed on CPU (b_image_token_tag is still on CPU here) so no D2H sync
+ # — keeps the eventual flash_attn_with_kvcache call CUDA-graph-safe.
+ b_max_image_q_idx_cpu = torch.full((batch,), -1, dtype=torch.int32)
+ for b in range(batch):
+ start = int(b_q_start_loc[b].item())
+ length = int(q_seq_lens[b].item())
+ seg = b_image_token_tag[start : start + length]
+ idx = torch.nonzero(seg, as_tuple=False)
+ if idx.numel() > 0:
+ b_max_image_q_idx_cpu[b] = int(idx[-1, 0].item())
+
+ q = torch.randn((sum_q, Hq, D), dtype=dtype, device=device)
+ k = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device)
+ v = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device)
+
+ return dict(
+ q=q,
+ k=k,
+ v=v,
+ b_req_idx=b_req_idx.to(device),
+ b_q_start_loc=b_q_start_loc.to(device),
+ b_seq_len=b_seq_len.to(device),
+ b_prompt_cache_len=b_prompt_cache_len.to(device),
+ max_seq_len_in_batch=max_seq_len_in_batch,
+ max_q_seq_len_in_batch=max_q_seq_len_in_batch,
+ req_to_token_indexs=req_to_token_indexs.to(device),
+ b_image_token_tag=b_image_token_tag.to(device),
+ b_max_image_q_idx=b_max_image_q_idx_cpu.to(device),
+ q_seq_lens=q_seq_lens,
+ prompt_cache_lens=prompt_cache_lens,
+ )
+
+
+def _fa3_prefill_with_image_tag(inputs: dict) -> torch.Tensor:
+ """Drive ``flash_attn_with_kvcache`` with the same prefill semantics as
+ ``Fa3PrefillAttState._nomarl_prefill_att`` plus an optional
+ ``image_token_tag`` kwarg for image-token bidirectional attention.
+ """
+ q = inputs["q"]
+ k = inputs["k"]
+ v = inputs["v"]
+ device = q.device
+
+ # Build page_table[b, p] = req_to_token_indexs[b_req_idx[b], p].
+ page_table = inputs["req_to_token_indexs"][
+ inputs["b_req_idx"].long(), : inputs["max_seq_len_in_batch"]
+ ].to(torch.int32)
+
+ q_seq_lens_t = inputs["b_seq_len"].to(torch.int32) - inputs["b_prompt_cache_len"].to(torch.int32)
+
+ cu_seqlens_q = torch.zeros(q_seq_lens_t.shape[0] + 1, dtype=torch.int32, device=device)
+ cu_seqlens_q[1:] = q_seq_lens_t.cumsum(0).to(torch.int32)
+
+ cu_seqlens_k = torch.zeros(inputs["b_seq_len"].shape[0] + 1, dtype=torch.int32, device=device)
+ cu_seqlens_k[1:] = inputs["b_seq_len"].cumsum(0).to(torch.int32)
+
+ sm_scale = 1.0 / math.sqrt(q.shape[-1])
+
+ # page_size = 1 paged KV cache view.
+ k_cache = k.view(k.shape[0], 1, k.shape[1], k.shape[2])
+ v_cache = v.view(v.shape[0], 1, v.shape[1], v.shape[2])
+
+ o = flash_attn_with_kvcache(
+ q=q,
+ k_cache=k_cache,
+ v_cache=v_cache,
+ page_table=page_table,
+ cache_seqlens=inputs["b_seq_len"],
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k_new=cu_seqlens_k,
+ max_seqlen_q=inputs["max_q_seq_len_in_batch"],
+ softmax_scale=sm_scale,
+ causal=True,
+ window_size=(-1, -1),
+ softcap=0.0,
+ k_descale=None,
+ v_descale=None,
+ return_softmax_lse=False,
+ # image-token bidirectional attention. Packed like q (shape [sum_q],
+ # bool). Rows where the tag is True are allowed to attend to every
+ # real key in the request (not just the causal prefix).
+ #
+ # b_max_image_q_idx is int32[batch]: per-batch last image-token index
+ # in batch-local coordinates, -1 if no image in that batch. Lets the
+ # fa3 kernel skip n_block_max extension for text-only requests, which
+ # is critical for mixed-modality batches. Pre-computed in
+ # _build_inputs (host-side) so this call is CUDA-graph safe.
+ image_token_tag=inputs["b_image_token_tag"],
+ max_image_q_idx=inputs["b_max_image_q_idx"],
+ )
+ return o
+
+
+def _report_per_batch_error(out_fa3, out_ref, q_seq_lens, b_q_start_loc, image_tag, tag=""):
+ print(f"\n[{tag}] per-batch error breakdown (abs / rel / cos):")
+ for i in range(q_seq_lens.shape[0]):
+ s = int(b_q_start_loc[i].item())
+ m = int(q_seq_lens[i].item())
+ if m == 0:
+ continue
+ a = out_fa3[s : s + m].float()
+ b = out_ref[s : s + m].float()
+ abs_err = (a - b).abs().max().item()
+ denom = b.abs().max().item() + 1e-6
+ rel_err = abs_err / denom
+ cos = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
+ n_img = int(image_tag[s : s + m].sum().item())
+ print(
+ f" batch {i:02d} | M={m:4d} | image_tokens={n_img:4d} | "
+ f"max_abs={abs_err:.4e} | max_rel={rel_err:.4e} | cos={cos:.6f}"
+ )
+
+
+def _run_case(
+ batch: int,
+ Hq: int,
+ Hk: int,
+ D: int,
+ dtype: torch.dtype,
+ seed: int,
+ max_q_seq_len: int,
+ max_prompt_cache_len: int,
+ atol: float = 5e-2,
+ rtol: float = 5e-2,
+ cos_threshold: float = 0.99,
+ verbose: bool = True,
+):
+ assert Hq % Hk == 0
+ device = "cuda"
+
+ inputs = _build_inputs(
+ batch=batch,
+ Hq=Hq,
+ Hk=Hk,
+ D=D,
+ dtype=dtype,
+ device=device,
+ max_q_seq_len=max_q_seq_len,
+ max_prompt_cache_len=max_prompt_cache_len,
+ seed=seed,
+ )
+
+ out_fa3 = _fa3_prefill_with_image_tag(inputs)
+
+ out_ref = torch_reference_context_attention_neo(
+ inputs["q"],
+ inputs["k"],
+ inputs["v"],
+ inputs["b_req_idx"],
+ inputs["b_q_start_loc"],
+ inputs["b_seq_len"],
+ inputs["b_prompt_cache_len"],
+ inputs["req_to_token_indexs"],
+ inputs["b_image_token_tag"],
+ )
+
+ a = out_fa3.float().reshape_as(out_ref.float())
+ b = out_ref.float()
+ abs_err = (a - b).abs().max().item()
+ denom = b.abs().max().item() + 1e-6
+ rel_err = abs_err / denom
+ cos = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
+
+ n_image = int(inputs["b_image_token_tag"].sum().item())
+ n_tokens = int(inputs["b_image_token_tag"].numel())
+ if verbose:
+ print(
+ f"\ncase: batch={batch} Hq={Hq} Hk={Hk} D={D} dtype={dtype} "
+ f"seed={seed} image_tokens={n_image}/{n_tokens}"
+ )
+ print(
+ f" global: max_abs={abs_err:.4e} max_rel={rel_err:.4e} cos={cos:.6f} "
+ f"(allclose atol={atol}, rtol={rtol}? "
+ f"{torch.allclose(a, b, atol=atol, rtol=rtol)})"
+ )
+ _report_per_batch_error(
+ out_fa3,
+ out_ref,
+ inputs["q_seq_lens"],
+ inputs["b_q_start_loc"],
+ inputs["b_image_token_tag"],
+ tag=f"seed={seed}",
+ )
+
+ return abs_err, rel_err, cos
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA")
+@pytest.mark.skipif(flash_attn_with_kvcache is None, reason="fa3 not available")
+@pytest.mark.parametrize(
+ "batch,Hq,Hk,D,dtype,seed,max_q_seq_len,max_prompt_cache_len",
+ [
+ (4, 8, 2, 128, torch.bfloat16, 0, 128, 256),
+ (4, 8, 2, 128, torch.bfloat16, 1, 256, 512),
+ (8, 16, 4, 128, torch.bfloat16, 2, 256, 512),
+ (16, 28, 4, 128, torch.bfloat16, 3, 128, 256),
+ (4, 8, 2, 128, torch.float16, 4, 256, 512),
+ (4, 8, 8, 64, torch.bfloat16, 5, 128, 256),
+ (3, 8, 2, 128, torch.bfloat16, 6, 8, 1024),
+ ],
+)
+def test_fa3_neo_prefill_with_image_tag(
+ batch, Hq, Hk, D, dtype, seed, max_q_seq_len, max_prompt_cache_len
+):
+ abs_err, rel_err, cos = _run_case(
+ batch=batch,
+ Hq=Hq,
+ Hk=Hk,
+ D=D,
+ dtype=dtype,
+ seed=seed,
+ max_q_seq_len=max_q_seq_len,
+ max_prompt_cache_len=max_prompt_cache_len,
+ verbose=True,
+ )
+ assert cos > 0.99, f"cosine similarity too low: {cos}"
+ assert rel_err < 5e-2, f"max relative error too large: {rel_err}"
+
+
+def _bench_case(
+ batch: int,
+ Hq: int,
+ Hk: int,
+ D: int,
+ dtype: torch.dtype,
+ seed: int,
+ max_q_seq_len: int,
+ max_prompt_cache_len: int,
+ rep_ms: int = 100,
+ warmup_iters: int = 3,
+):
+ """Compare FA3 (with image_token_tag) vs the original Triton
+ ``context_attention_fwd_neo`` using ``triton.testing.do_bench_cudagraph``.
+
+ Both kernels are captured into a CUDA graph so scheduling/launch overhead
+ is minimized and the measurement reflects the kernel cost.
+ """
+ assert Hq % Hk == 0
+ device = "cuda"
+
+ inputs = _build_inputs(
+ batch=batch,
+ Hq=Hq,
+ Hk=Hk,
+ D=D,
+ dtype=dtype,
+ device=device,
+ max_q_seq_len=max_q_seq_len,
+ max_prompt_cache_len=max_prompt_cache_len,
+ seed=seed,
+ )
+
+ # --- fa3 runner: output tensor is allocated inside flash_attn_with_kvcache.
+ def fa3_run():
+ return _fa3_prefill_with_image_tag(inputs)
+
+ # --- triton runner: pre-allocate o & position_ids so the graph captures
+ # only the kernel launch.
+ o_triton = torch.empty_like(inputs["q"])
+ # Kernel signature requires position_ids but the current masking path does
+ # not read it; zeros are fine for perf measurement.
+ position_ids_0 = torch.zeros(
+ inputs["q"].shape[0], dtype=torch.int32, device=inputs["q"].device
+ )
+
+ def triton_run():
+ context_attention_fwd_neo(
+ inputs["q"],
+ inputs["k"],
+ inputs["v"],
+ o_triton,
+ position_ids_0,
+ inputs["b_req_idx"],
+ inputs["b_q_start_loc"],
+ inputs["b_seq_len"],
+ inputs["b_prompt_cache_len"],
+ inputs["max_q_seq_len_in_batch"],
+ inputs["req_to_token_indexs"],
+ inputs["b_image_token_tag"],
+ )
+
+ # Warm up outside the graph capture so lazy allocations / autotune happen.
+ for _ in range(warmup_iters):
+ fa3_run()
+ triton_run()
+ torch.cuda.synchronize()
+
+ fa3_ms = triton_testing.do_bench_cudagraph(fa3_run, rep=rep_ms)
+ triton_ms = triton_testing.do_bench_cudagraph(triton_run, rep=rep_ms)
+
+ n_image = int(inputs["b_image_token_tag"].sum().item())
+ n_tokens = int(inputs["b_image_token_tag"].numel())
+ sum_kv = int(inputs["b_seq_len"].sum().item())
+ speedup = triton_ms / fa3_ms if fa3_ms > 0 else float("inf")
+
+ print(
+ f"bench: batch={batch} Hq={Hq} Hk={Hk} D={D} dtype={str(dtype).split('.')[-1]:<8s} "
+ f"max_q_seq_len={max_q_seq_len:4d} max_prompt_cache_len={max_prompt_cache_len:4d} "
+ f"image_tokens={n_image:4d}/{n_tokens:5d} sum_kv={sum_kv:6d} | "
+ f"fa3 {fa3_ms*1000:8.1f} us | triton {triton_ms*1000:8.1f} us | "
+ f"speedup {speedup:5.2f}x"
+ )
+
+ return fa3_ms, triton_ms
+
+
+if __name__ == "__main__":
+ if not torch.cuda.is_available():
+ print("No CUDA available.")
+ raise SystemExit(0)
+ if flash_attn_with_kvcache is None:
+ print("fa3 flash_attn_with_kvcache not available (sgl_kernel missing?).")
+ raise SystemExit(0)
+
+ torch.manual_seed(0)
+
+ cases = [
+ dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.bfloat16, seed=0, max_q_seq_len=128, max_prompt_cache_len=256),
+ dict(batch=8, Hq=16, Hk=4, D=128, dtype=torch.bfloat16, seed=1, max_q_seq_len=256, max_prompt_cache_len=512),
+ dict(batch=16, Hq=28, Hk=4, D=128, dtype=torch.bfloat16, seed=2, max_q_seq_len=128, max_prompt_cache_len=256),
+ dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.float16, seed=3, max_q_seq_len=256, max_prompt_cache_len=512),
+ ]
+
+ print("=" * 100)
+ print("Correctness")
+ print("=" * 100)
+ for cfg in cases:
+ _run_case(**cfg, verbose=True)
+
+ if triton_testing is None or context_attention_fwd_neo is None:
+ print("\nSkipping benchmark: triton or context_attention_fwd_neo not available.")
+ raise SystemExit(0)
+
+ print("\n" + "=" * 100)
+ print("Benchmark (triton.testing.do_bench_cudagraph)")
+ print("=" * 100)
+
+ # Cold-prefill sweep: max_prompt_cache_len=0 so seq_len == q_seq_len.
+ # Head shape matches neo_chat_moe / Qwen3 llm_config:
+ # num_attention_heads = 32, num_key_value_heads = 8, head_dim = 128
+ # (GQA ratio 4:1)
+ bench_batches = [8, 16, 32, 64, 128]
+ bench_q_seq_lens = [1024, 4096, 8192]
+
+ bench_cases = [
+ dict(
+ batch=b,
+ Hq=32,
+ Hk=8,
+ D=128,
+ dtype=torch.bfloat16,
+ seed=0,
+ max_q_seq_len=s,
+ max_prompt_cache_len=0,
+ )
+ for b in bench_batches
+ for s in bench_q_seq_lens
+ ]
+
+ for cfg in bench_cases:
+ _bench_case(**cfg)
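The test builds b_max_image_q_idx with a per-batch Python loop, which is fine host-side. If the same quantity ever has to be derived from the packed tensors without a loop, one possible vectorized variant (a sketch, not the implementation used by the kernel path) is a per-batch segment max via scatter_reduce:

    import torch

    def last_image_q_idx(image_tag: torch.Tensor, b_q_start_loc: torch.Tensor,
                         b_q_seq_len: torch.Tensor) -> torch.Tensor:
        # image_tag: bool[sum_q] packed like q; start/len are int[batch] host tensors.
        batch = b_q_seq_len.shape[0]
        batch_ids = torch.repeat_interleave(torch.arange(batch), b_q_seq_len.long())
        local_idx = torch.arange(image_tag.numel()) - b_q_start_loc.long()[batch_ids]
        vals = torch.where(image_tag, local_idx, torch.full_like(local_idx, -1))
        out = torch.full((batch,), -1, dtype=torch.int64)
        out = out.scatter_reduce(0, batch_ids, vals, reduce="amax", include_self=True)
        return out.to(torch.int32)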
From 1c8e7214598108ab3668556d6117b9b94c47a2e4 Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Mon, 20 Apr 2026 14:18:46 +0000
Subject: [PATCH 28/41] add fa3_neo
---
.../common/basemodel/attention/base_att.py | 5 ++
lightllm/common/basemodel/attention/fa3/fp.py | 10 ++++
.../layer_infer/transformer_layer_infer.py | 54 +++++++++++++------
lightllm/models/neo_chat/model.py | 13 +++++
lightllm/models/neo_chat_moe/infer_struct.py | 10 ++++
.../layer_infer/transformer_layer_infer.py | 52 ++++++++++++------
lightllm/models/neo_chat_moe/model.py | 13 +++++
.../triton_kernel/get_neo_position.py | 27 ++++++++++
8 files changed, 152 insertions(+), 32 deletions(-)
diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py
index 1286a46ec2..1ee38a1d03 100644
--- a/lightllm/common/basemodel/attention/base_att.py
+++ b/lightllm/common/basemodel/attention/base_att.py
@@ -70,6 +70,11 @@ class AttControl:
nsa_prefill_dict: Dict = None
nsa_decode: bool = False
nsa_decode_dict: Dict = None
+ image_token_tag: Optional[torch.Tensor] = None
+ # max_image_q_idx: int32[batch]. Local q index of the last image token in each
+ # batch; -1 for batches without image tokens. The fa3 kernel uses it to decide
+ # which M-blocks of a batch need n_block_max extended to the full seqlen_k,
+ # so image-free batches avoid doubling their compute.
+ max_image_q_idx: Optional[torch.Tensor] = None
@dataclass
diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py
index 952bb39d91..3c99c4d095 100644
--- a/lightllm/common/basemodel/attention/fa3/fp.py
+++ b/lightllm/common/basemodel/attention/fa3/fp.py
@@ -92,6 +92,15 @@ def _nomarl_prefill_att(
k_descale, v_descale = None, None # disable quantization
Lq = q.shape[-1]
sm_scale = 1.0 / (Lq ** 0.5)
+
+ # Optional image-token bidirectional attention (neo_chat*). When None,
+ # the kernel path reduces to plain causal fa3 (zero overhead).
+ extra_kwargs = {}
+ if att_control.image_token_tag is not None:
+ extra_kwargs["image_token_tag"] = att_control.image_token_tag
+ if att_control.max_image_q_idx is not None:
+ extra_kwargs["max_image_q_idx"] = att_control.max_image_q_idx
+
o = flash_attn_with_kvcache(
q=q,
k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),
@@ -109,6 +118,7 @@ def _nomarl_prefill_att(
v_descale=v_descale,
return_softmax_lse=False,
sinks=sink_weight,
+ **extra_kwargs,
)
return o
diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
index 4517b5688a..0628411cbf 100644
--- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -1,3 +1,4 @@
+import os
import torch
from functools import partial
from typing import Tuple
@@ -12,6 +13,8 @@
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.common.basemodel.attention.base_att import AttControl
+_USE_TRITON_PREFILL = os.environ.get("LIGHTLLM_NEO_PREFILL_TRITON_BACKEND", "0").strip().lower() in ("1", "true")
+
class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer):
def __init__(self, data_type, network_config):
@@ -83,23 +86,42 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC
def _context_attention_kernel(
self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
) -> torch.Tensor:
- o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
- kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
- context_attention_fwd_neo(
- q.view(-1, self.tp_q_head_num_, self.head_dim_),
- kv[:, 0 : self.tp_k_head_num_, :],
- kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
- o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
- infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
- infer_state.b_req_idx,
- infer_state.b_q_start_loc,
- infer_state.b_seq_len,
- infer_state.b_ready_cache_len,
- infer_state.max_q_seq_len,
- infer_state.req_manager.req_to_token_indexs,
- infer_state.b_image_token_tag,
+
+
+ if _USE_TRITON_PREFILL:
+ o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+ kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+ context_attention_fwd_neo(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ kv[:, 0 : self.tp_k_head_num_, :],
+ kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+ o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
+ infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
+ infer_state.b_req_idx,
+ infer_state.b_q_start_loc,
+ infer_state.b_seq_len,
+ infer_state.b_ready_cache_len,
+ infer_state.max_q_seq_len,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.b_image_token_tag,
+ )
+ return o_tensor
+
+ _q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
+ _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_)
+
+ att_control = AttControl()
+ att_control.image_token_tag = getattr(infer_state, "b_image_token_tag", None)
+ att_control.max_image_q_idx = getattr(infer_state, "b_max_image_q_idx", None)
+
+ o_tensor = infer_state.prefill_att_state.prefill_att(
+ q=_q,
+ k=_k,
+ v=_v,
+ att_control=att_control,
+ alloc_func=self.alloc_tensor,
)
- return o_tensor
+ return o_tensor.view(q.shape)
def _token_attention_kernel(
self,
diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py
index 14d9f96dc7..e17ad31fa1 100644
--- a/lightllm/models/neo_chat/model.py
+++ b/lightllm/models/neo_chat/model.py
@@ -20,6 +20,11 @@
from lightllm.models.neo_chat.layer_weights.pre_and_post_layer_weight import NeoChatPreAndPostLayerWeight
from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+from lightllm.common.basemodel.attention import (
+ get_prefill_att_backend_class,
+ get_decode_att_backend_class,
+ BaseAttBackend,
+)
@ModelRegistry(["neo_chat"], is_multimodal=True, condition=llm_model_type_is("qwen3"))
@@ -40,6 +45,14 @@ def __init__(self, kvargs):
def _init_inferstate_cls(self):
pass
+ def _init_att_backend(self):
+ self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(
+ index=0, priority_list=["fa3"]
+ )(model=self)
+ self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(
+ index=0, priority_list=["fa3"]
+ )(model=self)
+
def _init_config(self):
with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
all_config = json.load(json_file)
diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py
index 1693bcb964..d7256af2b2 100644
--- a/lightllm/models/neo_chat_moe/infer_struct.py
+++ b/lightllm/models/neo_chat_moe/infer_struct.py
@@ -20,9 +20,18 @@ def __init__(self):
def init_some_extra_state(self, model: LlamaTpPartModel):
LlamaInferStateInfo.init_some_extra_state(self, model)
if self.is_prefill:
+ bsz = self.b_q_seq_len.shape[0]
self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda(
non_blocking=True
)
+ # Pre-allocate to -1 so the "no images anywhere" fast path in
+ # get_neo_position (which skips the triton kernel entirely) still
+ # yields a valid per-batch tensor. When the kernel runs, it
+ # overwrites every batch slot unconditionally with its computed
+ # value (-1 if no image lands in that batch's q window).
+ self.b_max_image_q_idx = torch.full(
+ (bsz,), -1, dtype=torch.int32, device=self.b_q_seq_len.device
+ )
self.position_ids = self.get_neo_position(self.multimodal_params)
else:
b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])]
@@ -97,5 +106,6 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor:
b_q_seq_len=self.b_q_seq_len,
b_start_loc=self.b_q_start_loc,
b_image_token_tag=self.b_image_token_tag,
+ b_max_image_q_idx=self.b_max_image_q_idx,
)
return position_ids
diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
index 4c4d8a22ab..4fe1d2e9d1 100644
--- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
@@ -1,3 +1,4 @@
+import os
import torch
from functools import partial
from typing import Tuple
@@ -12,6 +13,8 @@
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.common.basemodel.attention.base_att import AttControl
+_USE_TRITON_PREFILL = os.environ.get("LIGHTLLM_NEO_PREFILL_TRITON_BACKEND", "0").strip().lower() in ("1", "true")
+
class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
def __init__(self, data_type, network_config):
@@ -82,23 +85,40 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC
def _context_attention_kernel(
self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
) -> torch.Tensor:
- o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
- kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
- context_attention_fwd_neo(
- q.view(-1, self.tp_q_head_num_, self.head_dim_),
- kv[:, 0 : self.tp_k_head_num_, :],
- kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
- o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
- infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
- infer_state.b_req_idx,
- infer_state.b_q_start_loc,
- infer_state.b_seq_len,
- infer_state.b_ready_cache_len,
- infer_state.max_q_seq_len,
- infer_state.req_manager.req_to_token_indexs,
- infer_state.b_image_token_tag,
+ if _USE_TRITON_PREFILL:
+ o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
+ kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+ context_attention_fwd_neo(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ kv[:, 0 : self.tp_k_head_num_, :],
+ kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
+ o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
+ infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
+ infer_state.b_req_idx,
+ infer_state.b_q_start_loc,
+ infer_state.b_seq_len,
+ infer_state.b_ready_cache_len,
+ infer_state.max_q_seq_len,
+ infer_state.req_manager.req_to_token_indexs,
+ infer_state.b_image_token_tag,
+ )
+ return o_tensor
+
+ _q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
+ _k, _v = infer_state.mem_manager.get_att_input_params(layer_index=self.layer_num_)
+
+ att_control = AttControl()
+ att_control.image_token_tag = getattr(infer_state, "b_image_token_tag", None)
+ att_control.max_image_q_idx = getattr(infer_state, "b_max_image_q_idx", None)
+
+ o_tensor = infer_state.prefill_att_state.prefill_att(
+ q=_q,
+ k=_k,
+ v=_v,
+ att_control=att_control,
+ alloc_func=self.alloc_tensor,
)
- return o_tensor
+ return o_tensor.view(q.shape)
def _token_attention_kernel(
self,
diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py
index dba8965512..7ae3cd27c0 100644
--- a/lightllm/models/neo_chat_moe/model.py
+++ b/lightllm/models/neo_chat_moe/model.py
@@ -20,6 +20,11 @@
from lightllm.models.neo_chat_moe.layer_weights.pre_and_post_layer_weight import NeoChatMOEPreAndPostLayerWeight
from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer
from lightllm.models.neo_chat_moe.infer_struct import NeoChatInferStateInfo
+from lightllm.common.basemodel.attention import (
+ get_prefill_att_backend_class,
+ get_decode_att_backend_class,
+ BaseAttBackend,
+)
IMG_START_TOKEN = "
"
IMG_END_TOKEN = ""
@@ -182,6 +187,14 @@ def __init__(self, kvargs):
def _init_inferstate_cls(self):
pass
+ def _init_att_backend(self):
+ self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(
+ index=0, priority_list=["fa3"]
+ )(model=self)
+ self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(
+ index=0, priority_list=["fa3"]
+ )(model=self)
+
def _init_config(self):
with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
all_config = json.load(json_file)
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
index 1a3d4af73b..223befad62 100644
--- a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
+++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
@@ -17,6 +17,7 @@ def _get_neo_position_triton(
b_q_seq_len: torch.Tensor,
b_start_loc: torch.Tensor,
b_image_token_tag: torch.Tensor,
+ b_max_image_q_idx: torch.Tensor,
BLOCK_SIZE: tl.constexpr,
) -> torch.Tensor:
cur_batch = tl.program_id(0)
@@ -25,12 +26,30 @@ def _get_neo_position_triton(
image_num = tl.load(b_image_nums + cur_batch)
image_start_num = tl.load(b_image_start_num + cur_batch)
start_loc = tl.load(b_start_loc + cur_batch)
+
+ # Track per-batch last (batch-local) packed-q index where tag=True is
+ # written. -1 means this batch has no image tokens in its q range (either
+ # no images at all, or every image is entirely in prompt cache / out of
+ # the current q window).
+ max_image_q_idx = -1
+
for i in range(image_num):
local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
image_start_idx = start_loc + local_image_start_idx - cache_len
image_len = tl.load(b_image_len + image_start_num + i)
# image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
+
+ # Batch-local index of this image's last token that actually lands in
+ # the current q window (after clipping to [0, q_seq_len)). Matches the
+ # mask applied in the stores below.
+ cand_raw = local_image_start_idx - cache_len + image_len - 1
+ candidate = tl.minimum(cand_raw, q_seq_len - 1)
+ contributes = (cand_raw >= 0) & (local_image_start_idx - cache_len < q_seq_len)
+ max_image_q_idx = tl.where(
+ contributes, tl.maximum(max_image_q_idx, candidate), max_image_q_idx
+ )
+
for j in range(0, image_len, BLOCK_SIZE):
off = j + tl.arange(0, BLOCK_SIZE)
# Video is not handled yet, so t is always 0
@@ -66,6 +85,8 @@ def _get_neo_position_triton(
& (local_image_start_idx - cache_len + off >= 0),
)
+ tl.store(b_max_image_q_idx + cur_batch, max_image_q_idx)
+
for i in range(image_num):
local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
image_len = tl.load(b_image_len + image_start_num + i)
@@ -96,10 +117,12 @@ def get_neo_position_triton(
b_q_seq_len: torch.Tensor,
b_start_loc: torch.Tensor,
b_image_token_tag: torch.Tensor,
+ b_max_image_q_idx: torch.Tensor,
) -> torch.Tensor:
batch_size = b_q_seq_len.shape[0]
assert batch_size == b_image_nums.shape[0]
+ assert b_max_image_q_idx.shape[0] == batch_size
grid = (batch_size,)
BLOCK_SIZE = 64
_get_neo_position_triton[grid](
@@ -115,6 +138,7 @@ def get_neo_position_triton(
b_q_seq_len=b_q_seq_len,
b_start_loc=b_start_loc,
b_image_token_tag=b_image_token_tag,
+ b_max_image_q_idx=b_max_image_q_idx,
BLOCK_SIZE=BLOCK_SIZE,
)
@@ -136,6 +160,7 @@ def test():
b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda")
b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda")
b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda")
+ b_max_image_q_idx = torch.full((b_q_seq_len.shape[0],), -1, dtype=torch.int32, device="cuda")
get_neo_position_triton(
b_image_start_idx,
b_image_thwd,
@@ -147,10 +172,12 @@ def test():
b_q_seq_len,
b_start_loc,
b_image_token_tag,
+ b_max_image_q_idx,
)
print(b_image_token_tag)
print(position_ids)
+ print(b_max_image_q_idx)
# old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1)
# position_ids = (
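For reference, the per-batch rule the kernel above tracks can be restated in a few lines of host-side Python. This is a hedged sketch, not LightLLM code; the argument names and the single-image example values are illustrative only.

def max_image_q_idx(images, cache_len, q_seq_len):
    # images: list of (local_start_idx, image_len) pairs in request-local coordinates
    best = -1
    for local_start, img_len in images:
        cand_raw = local_start - cache_len + img_len - 1   # last image token, q-local
        candidate = min(cand_raw, q_seq_len - 1)           # clip into the q window
        contributes = cand_raw >= 0 and (local_start - cache_len) < q_seq_len
        if contributes:
            best = max(best, candidate)
    return best

# One 4-token image starting at request offset 5, with 3 tokens already cached and a
# 7-token q window: its last token lands at batch-local q index 5.
assert max_image_q_idx([(5, 4)], cache_len=3, q_seq_len=7) == 5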
From 783851265da952292a066fae8f4dc06f89a574fe Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Mon, 20 Apr 2026 15:03:32 +0000
Subject: [PATCH 29/41] import flash_attn_with_kvcache_neo
---
lightllm/common/basemodel/attention/fa3/fp.py | 32 +++++++++++++++----
1 file changed, 25 insertions(+), 7 deletions(-)
diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py
index 3c99c4d095..e1ddcd8b85 100644
--- a/lightllm/common/basemodel/attention/fa3/fp.py
+++ b/lightllm/common/basemodel/attention/fa3/fp.py
@@ -4,6 +4,7 @@
from typing import Optional, TYPE_CHECKING
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
+from flash_attn_interface import flash_attn_with_kvcache as flash_attn_with_kvcache_neo
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
@@ -93,13 +94,31 @@ def _nomarl_prefill_att(
Lq = q.shape[-1]
sm_scale = 1.0 / (Lq ** 0.5)
- # Optional image-token bidirectional attention (neo_chat*). When None,
- # the kernel path reduces to plain causal fa3 (zero overhead).
- extra_kwargs = {}
+ # neo_chat*: image-token bidirectional attention requires flash_attn_interface
+ # (sgl_kernel's flash_attn_with_kvcache does not support image_token_tag).
if att_control.image_token_tag is not None:
- extra_kwargs["image_token_tag"] = att_control.image_token_tag
- if att_control.max_image_q_idx is not None:
- extra_kwargs["max_image_q_idx"] = att_control.max_image_q_idx
+ extra_kwargs = {"image_token_tag": att_control.image_token_tag}
+ if att_control.max_image_q_idx is not None:
+ extra_kwargs["max_image_q_idx"] = att_control.max_image_q_idx
+ o = flash_attn_with_kvcache_neo(
+ q=q,
+ k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),
+ v_cache=v.view(v.shape[0], 1, v.shape[1], v.shape[2]),
+ page_table=self.page_table,
+ cache_seqlens=self.infer_state.b_seq_len,
+ cu_seqlens_q=self.cu_seqlens_q,
+ cu_seqlens_k_new=self.cu_seqlens_k,
+ max_seqlen_q=self.infer_state.max_q_seq_len,
+ softmax_scale=sm_scale,
+ causal=True,
+ window_size=window_size,
+ softcap=0.0,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ return_softmax_lse=False,
+ **extra_kwargs,
+ )
+ return o
o = flash_attn_with_kvcache(
q=q,
@@ -118,7 +137,6 @@ def _nomarl_prefill_att(
v_descale=v_descale,
return_softmax_lse=False,
sinks=sink_weight,
- **extra_kwargs,
)
return o
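One detail worth calling out in the call above is the k/v `.view(...)`: when a page_table is passed, the flash-attn kvcache entry point expects a paged cache of shape [num_pages, page_size, heads, head_dim], while LightLLM keeps a token-granular pool of shape [num_tokens, heads, head_dim]. Treating each token slot as a page of size 1 makes the two layouts line up, and the page_table then simply holds per-token slot indices. A minimal sketch with assumed, illustrative sizes:

import torch

k_pool = torch.randn(4096, 8, 128)                       # [num_tokens, kv_heads, head_dim]
k_cache = k_pool.view(k_pool.shape[0], 1, k_pool.shape[1], k_pool.shape[2])
assert k_cache.shape == (4096, 1, 8, 128)                 # [num_pages, page_size=1, heads, dim]
assert k_cache.data_ptr() == k_pool.data_ptr()            # a view, no copy involved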
From b202de03de2d813feb488911a9276fb88c9cfc3e Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Mon, 20 Apr 2026 15:46:55 +0000
Subject: [PATCH 30/41] fix: cast the clipped max_image_q_idx candidate to int32
---
lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
index 223befad62..9756371af2 100644
--- a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
+++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
@@ -44,7 +44,7 @@ def _get_neo_position_triton(
# the current q window (after clipping to [0, q_seq_len)). Matches the
# mask applied in the stores below.
cand_raw = local_image_start_idx - cache_len + image_len - 1
- candidate = tl.minimum(cand_raw, q_seq_len - 1)
+ candidate = tl.minimum(cand_raw, q_seq_len - 1).to(tl.int32)
contributes = (cand_raw >= 0) & (local_image_start_idx - cache_len < q_seq_len)
max_image_q_idx = tl.where(
contributes, tl.maximum(max_image_q_idx, candidate), max_image_q_idx
From b31991509298d88f35ba0c727496cb985fb02c3c Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Tue, 21 Apr 2026 06:31:26 +0000
Subject: [PATCH 31/41] reduce useless nblock
---
.../triton_kernel/context_attention_fwd_neo.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
index a745b7c9da..c9795e113a 100644
--- a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
+++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
@@ -77,12 +77,15 @@ def _fwd_kernel(
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)
- block_end_loc = total_len
-
# absolute q positions in the request
q_pos = prompt_cache_len + offs_m # [M]
q_image_token_tag = tl.load(b_image_token_tag + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=False)
+ # per-M-block: only scan full K range if this block has image tokens
+ has_image = tl.max(q_image_token_tag.to(tl.int32), axis=0) > 0
+ causal_end = tl.minimum(prompt_cache_len + block_start_loc + BLOCK_M, total_len)
+ block_end_loc = tl.where(has_image, total_len, causal_end)
+
for start_n in range(0, block_end_loc, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
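The effect of the hunk above is easiest to see as plain arithmetic: a text-only M-block only scans keys up to its own causal frontier, while a block containing at least one image row still scans every key. A minimal sketch, with assumed block sizes rather than the kernel's actual launch configuration:

def block_end_loc(has_image, prompt_cache_len, block_start_loc, BLOCK_M, total_len):
    causal_end = min(prompt_cache_len + block_start_loc + BLOCK_M, total_len)
    return total_len if has_image else causal_end

# 256 cached tokens, first q block of 64 rows, 512 keys in total:
assert block_end_loc(False, 256, 0, 64, 512) == 320   # text-only block stops early
assert block_end_loc(True, 256, 0, 64, 512) == 512    # image block scans all keys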
From 468ea9839926812e89717479e40fbc13a12ec0ed Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Tue, 21 Apr 2026 07:26:34 +0000
Subject: [PATCH 32/41] delete max_image_q_idx and reduce_or in kernel
---
.../common/basemodel/attention/base_att.py | 4 ---
lightllm/common/basemodel/attention/fa3/fp.py | 2 --
.../layer_infer/transformer_layer_infer.py | 1 -
lightllm/models/neo_chat_moe/infer_struct.py | 9 -------
.../layer_infer/transformer_layer_infer.py | 1 -
.../triton_kernel/get_neo_position.py | 25 -----------------
.../att/prefill_att/test_fa3_neo.py | 27 +++----------------
7 files changed, 4 insertions(+), 65 deletions(-)
diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py
index 1ee38a1d03..23a02bba13 100644
--- a/lightllm/common/basemodel/attention/base_att.py
+++ b/lightllm/common/basemodel/attention/base_att.py
@@ -71,10 +71,6 @@ class AttControl:
nsa_decode: bool = False
nsa_decode_dict: Dict = None
image_token_tag: Optional[torch.Tensor] = None
- # max_image_q_idx: int32[batch]. 每个 batch 内最后一个 image token 的 local q
- # 下标;没有 image token 的 batch 填 -1。fa3 kernel 据此决定该 batch 的哪些
- # M-block 需要把 n_block_max 延长到 full seqlen_k,避免无 image 的 batch 翻倍计算。
- max_image_q_idx: Optional[torch.Tensor] = None
@dataclass
diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py
index e1ddcd8b85..8ce12ad335 100644
--- a/lightllm/common/basemodel/attention/fa3/fp.py
+++ b/lightllm/common/basemodel/attention/fa3/fp.py
@@ -98,8 +98,6 @@ def _nomarl_prefill_att(
# (sgl_kernel's flash_attn_with_kvcache does not support image_token_tag).
if att_control.image_token_tag is not None:
extra_kwargs = {"image_token_tag": att_control.image_token_tag}
- if att_control.max_image_q_idx is not None:
- extra_kwargs["max_image_q_idx"] = att_control.max_image_q_idx
o = flash_attn_with_kvcache_neo(
q=q,
k_cache=k.view(k.shape[0], 1, k.shape[1], k.shape[2]),
diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
index 0628411cbf..1f5f223654 100644
--- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -112,7 +112,6 @@ def _context_attention_kernel(
att_control = AttControl()
att_control.image_token_tag = getattr(infer_state, "b_image_token_tag", None)
- att_control.max_image_q_idx = getattr(infer_state, "b_max_image_q_idx", None)
o_tensor = infer_state.prefill_att_state.prefill_att(
q=_q,
diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py
index d7256af2b2..b8e6483def 100644
--- a/lightllm/models/neo_chat_moe/infer_struct.py
+++ b/lightllm/models/neo_chat_moe/infer_struct.py
@@ -24,14 +24,6 @@ def init_some_extra_state(self, model: LlamaTpPartModel):
self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda(
non_blocking=True
)
- # Pre-allocate to -1 so the "no images anywhere" fast path in
- # get_neo_position (which skips the triton kernel entirely) still
- # yields a valid per-batch tensor. When the kernel runs, it
- # overwrites every batch slot unconditionally with its computed
- # value (-1 if no image lands in that batch's q window).
- self.b_max_image_q_idx = torch.full(
- (bsz,), -1, dtype=torch.int32, device=self.b_q_seq_len.device
- )
self.position_ids = self.get_neo_position(self.multimodal_params)
else:
b_position_delta = [0 for _ in range(self.b_seq_len.shape[0])]
@@ -106,6 +98,5 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor:
b_q_seq_len=self.b_q_seq_len,
b_start_loc=self.b_q_start_loc,
b_image_token_tag=self.b_image_token_tag,
- b_max_image_q_idx=self.b_max_image_q_idx,
)
return position_ids
diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
index 4fe1d2e9d1..eab7b575bf 100644
--- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
@@ -109,7 +109,6 @@ def _context_attention_kernel(
att_control = AttControl()
att_control.image_token_tag = getattr(infer_state, "b_image_token_tag", None)
- att_control.max_image_q_idx = getattr(infer_state, "b_max_image_q_idx", None)
o_tensor = infer_state.prefill_att_state.prefill_att(
q=_q,
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
index 9756371af2..dc57870bf1 100644
--- a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
+++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
@@ -17,7 +17,6 @@ def _get_neo_position_triton(
b_q_seq_len: torch.Tensor,
b_start_loc: torch.Tensor,
b_image_token_tag: torch.Tensor,
- b_max_image_q_idx: torch.Tensor,
BLOCK_SIZE: tl.constexpr,
) -> torch.Tensor:
cur_batch = tl.program_id(0)
@@ -27,12 +26,6 @@ def _get_neo_position_triton(
image_start_num = tl.load(b_image_start_num + cur_batch)
start_loc = tl.load(b_start_loc + cur_batch)
- # Track per-batch last (batch-local) packed-q index where tag=True is
- # written. -1 means this batch has no image tokens in its q range (either
- # no images at all, or every image is entirely in prompt cache / out of
- # the current q window).
- max_image_q_idx = -1
-
for i in range(image_num):
local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
image_start_idx = start_loc + local_image_start_idx - cache_len
@@ -40,16 +33,6 @@ def _get_neo_position_triton(
# image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
- # Batch-local index of this image's last token that actually lands in
- # the current q window (after clipping to [0, q_seq_len)). Matches the
- # mask applied in the stores below.
- cand_raw = local_image_start_idx - cache_len + image_len - 1
- candidate = tl.minimum(cand_raw, q_seq_len - 1).to(tl.int32)
- contributes = (cand_raw >= 0) & (local_image_start_idx - cache_len < q_seq_len)
- max_image_q_idx = tl.where(
- contributes, tl.maximum(max_image_q_idx, candidate), max_image_q_idx
- )
-
for j in range(0, image_len, BLOCK_SIZE):
off = j + tl.arange(0, BLOCK_SIZE)
# videos are not considered yet, so t is always 0
@@ -85,8 +68,6 @@ def _get_neo_position_triton(
& (local_image_start_idx - cache_len + off >= 0),
)
- tl.store(b_max_image_q_idx + cur_batch, max_image_q_idx)
-
for i in range(image_num):
local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
image_len = tl.load(b_image_len + image_start_num + i)
@@ -117,12 +98,10 @@ def get_neo_position_triton(
b_q_seq_len: torch.Tensor,
b_start_loc: torch.Tensor,
b_image_token_tag: torch.Tensor,
- b_max_image_q_idx: torch.Tensor,
) -> torch.Tensor:
batch_size = b_q_seq_len.shape[0]
assert batch_size == b_image_nums.shape[0]
- assert b_max_image_q_idx.shape[0] == batch_size
grid = (batch_size,)
BLOCK_SIZE = 64
_get_neo_position_triton[grid](
@@ -138,7 +117,6 @@ def get_neo_position_triton(
b_q_seq_len=b_q_seq_len,
b_start_loc=b_start_loc,
b_image_token_tag=b_image_token_tag,
- b_max_image_q_idx=b_max_image_q_idx,
BLOCK_SIZE=BLOCK_SIZE,
)
@@ -160,7 +138,6 @@ def test():
b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda")
b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda")
b_start_loc = torch.tensor([0, 7], dtype=torch.int32, device="cuda")
- b_max_image_q_idx = torch.full((b_q_seq_len.shape[0],), -1, dtype=torch.int32, device="cuda")
get_neo_position_triton(
b_image_start_idx,
b_image_thwd,
@@ -172,12 +149,10 @@ def test():
b_q_seq_len,
b_start_loc,
b_image_token_tag,
- b_max_image_q_idx,
)
print(b_image_token_tag)
print(position_ids)
- print(b_max_image_q_idx)
# old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1)
# position_ids = (
diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py
index 79b1e8e343..9a635f0445 100644
--- a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py
+++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py
@@ -182,20 +182,6 @@ def _build_inputs(
b_seq_len = seq_lens.to(torch.int32)
b_prompt_cache_len = prompt_cache_lens.to(torch.int32)
- # Per-batch last image-token index in *batch-local* packed-q coordinates.
- # Matches NeoChatInferStateInfo._compute_b_max_image_q_idx semantics:
- # shape int32[batch]; value == -1 means that batch has no image tokens.
- # Computed on CPU (b_image_token_tag is still on CPU here) so no D2H sync
- # — keeps the eventual flash_attn_with_kvcache call CUDA-graph-safe.
- b_max_image_q_idx_cpu = torch.full((batch,), -1, dtype=torch.int32)
- for b in range(batch):
- start = int(b_q_start_loc[b].item())
- length = int(q_seq_lens[b].item())
- seg = b_image_token_tag[start : start + length]
- idx = torch.nonzero(seg, as_tuple=False)
- if idx.numel() > 0:
- b_max_image_q_idx_cpu[b] = int(idx[-1, 0].item())
-
q = torch.randn((sum_q, Hq, D), dtype=dtype, device=device)
k = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device)
v = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device)
@@ -212,7 +198,6 @@ def _build_inputs(
max_q_seq_len_in_batch=max_q_seq_len_in_batch,
req_to_token_indexs=req_to_token_indexs.to(device),
b_image_token_tag=b_image_token_tag.to(device),
- b_max_image_q_idx=b_max_image_q_idx_cpu.to(device),
q_seq_lens=q_seq_lens,
prompt_cache_lens=prompt_cache_lens,
)
@@ -266,14 +251,9 @@ def _fa3_prefill_with_image_tag(inputs: dict) -> torch.Tensor:
# image-token bidirectional attention. Packed like q (shape [sum_q],
# bool). Rows where the tag is True are allowed to attend to every
# real key in the request (not just the causal prefix).
- #
- # b_max_image_q_idx is int32[batch]: per-batch last image-token index
- # in batch-local coordinates, -1 if no image in that batch. Lets the
- # fa3 kernel skip n_block_max extension for text-only requests, which
- # is critical for mixed-modality batches. Pre-computed in
- # _build_inputs (host-side) so this call is CUDA-graph safe.
+ # The kernel uses warp OR reduce to detect image tokens per M-block
+ # and extends n_block_max for full attention automatically.
image_token_tag=inputs["b_image_token_tag"],
- max_image_q_idx=inputs["b_max_image_q_idx"],
)
return o
@@ -505,7 +485,8 @@ def triton_run():
dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.bfloat16, seed=0, max_q_seq_len=128, max_prompt_cache_len=256),
dict(batch=8, Hq=16, Hk=4, D=128, dtype=torch.bfloat16, seed=1, max_q_seq_len=256, max_prompt_cache_len=512),
dict(batch=16, Hq=28, Hk=4, D=128, dtype=torch.bfloat16, seed=2, max_q_seq_len=128, max_prompt_cache_len=256),
- dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.float16, seed=3, max_q_seq_len=256, max_prompt_cache_len=512),
+ # FP16 case disabled: current build compiled without FP16 support
+ # dict(batch=4, Hq=8, Hk=2, D=128, dtype=torch.float16, seed=3, max_q_seq_len=256, max_prompt_cache_len=512),
]
print("=" * 100)
From 77dc66fa71122710b85afd2cb94c96dc26046c4e Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Tue, 21 Apr 2026 08:22:01 +0000
Subject: [PATCH 33/41] fall back to triton prefill when the fa3 import fails
---
lightllm/common/basemodel/attention/fa3/fp.py | 15 ++++++++++++++-
.../layer_infer/transformer_layer_infer.py | 11 +++++++++++
.../layer_infer/transformer_layer_infer.py | 11 +++++++++++
3 files changed, 36 insertions(+), 1 deletion(-)
diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py
index 8ce12ad335..6a0ecfc1af 100644
--- a/lightllm/common/basemodel/attention/fa3/fp.py
+++ b/lightllm/common/basemodel/attention/fa3/fp.py
@@ -4,7 +4,14 @@
from typing import Optional, TYPE_CHECKING
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.utils.sgl_utils import flash_attn_with_kvcache
-from flash_attn_interface import flash_attn_with_kvcache as flash_attn_with_kvcache_neo
+
+try:
+ from flash_attn_interface import flash_attn_with_kvcache as flash_attn_with_kvcache_neo
+
+ HAS_FLASH_ATTN_INTERFACE = True
+except ImportError:
+ flash_attn_with_kvcache_neo = None
+ HAS_FLASH_ATTN_INTERFACE = False
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
@@ -97,6 +104,12 @@ def _nomarl_prefill_att(
# neo_chat*: image-token bidirectional attention requires flash_attn_interface
# (sgl_kernel's flash_attn_with_kvcache does not support image_token_tag).
if att_control.image_token_tag is not None:
+ if not HAS_FLASH_ATTN_INTERFACE:
+ raise ImportError(
+ "flash_attn_interface (fa3-neo) is required for image_token_tag bidirectional "
+ "attention. Install it or set LIGHTLLM_NEO_PREFILL_TRITON_BACKEND=1 to use the "
+ "triton fallback."
+ )
extra_kwargs = {"image_token_tag": att_control.image_token_tag}
o = flash_attn_with_kvcache_neo(
q=q,
diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
index 1f5f223654..a820ddaaf6 100644
--- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -12,8 +12,19 @@
import torch.distributed as dist
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.common.basemodel.attention.base_att import AttControl
+from lightllm.common.basemodel.attention.fa3.fp import HAS_FLASH_ATTN_INTERFACE
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
_USE_TRITON_PREFILL = os.environ.get("LIGHTLLM_NEO_PREFILL_TRITON_BACKEND", "0").strip().lower() in ("1", "true")
+if not _USE_TRITON_PREFILL and not HAS_FLASH_ATTN_INTERFACE:
+ logger.warning(
+ "flash_attn_interface (fa3-neo) is not installed; falling back to triton prefill backend "
+ "for neo_chat. Install fa3-neo or set LIGHTLLM_NEO_PREFILL_TRITON_BACKEND=1 to silence "
+ "this warning."
+ )
+ _USE_TRITON_PREFILL = True
class NeoChatTransformerLayerInfer(Qwen3TransformerLayerInfer):
diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
index eab7b575bf..66a320fe02 100644
--- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
@@ -12,8 +12,19 @@
import torch.distributed as dist
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.common.basemodel.attention.base_att import AttControl
+from lightllm.common.basemodel.attention.fa3.fp import HAS_FLASH_ATTN_INTERFACE
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
_USE_TRITON_PREFILL = os.environ.get("LIGHTLLM_NEO_PREFILL_TRITON_BACKEND", "0").strip().lower() in ("1", "true")
+if not _USE_TRITON_PREFILL and not HAS_FLASH_ATTN_INTERFACE:
+ logger.warning(
+ "flash_attn_interface (fa3-neo) is not installed; falling back to triton prefill backend "
+ "for neo_chat_moe. Install fa3-neo or set LIGHTLLM_NEO_PREFILL_TRITON_BACKEND=1 to silence "
+ "this warning."
+ )
+ _USE_TRITON_PREFILL = True
class NeoChatMOETransformerLayerInfer(Qwen3MOETransformerLayerInfer):
From 4959f4e86383480c3ce60c245ec41e697492b95e Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Tue, 21 Apr 2026 08:35:49 +0000
Subject: [PATCH 34/41] verify fa3 image_token_tag
---
lightllm/common/basemodel/attention/fa3/fp.py | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/lightllm/common/basemodel/attention/fa3/fp.py b/lightllm/common/basemodel/attention/fa3/fp.py
index 6a0ecfc1af..2e63a8dc2e 100644
--- a/lightllm/common/basemodel/attention/fa3/fp.py
+++ b/lightllm/common/basemodel/attention/fa3/fp.py
@@ -7,6 +7,12 @@
try:
from flash_attn_interface import flash_attn_with_kvcache as flash_attn_with_kvcache_neo
+ import inspect
+
+ # Verify this is the neo-patched FA3 build (with image_token_tag support).
+ _sig = inspect.signature(flash_attn_with_kvcache_neo)
+ if "image_token_tag" not in _sig.parameters:
+ raise ImportError("flash_attn_interface found but missing image_token_tag support (need neo build)")
HAS_FLASH_ATTN_INTERFACE = True
except ImportError:
From 77f224a06d5dbf4a2ec0329841c906c96742db35 Mon Sep 17 00:00:00 2001
From: Charles2530 <2569337619@qq.com>
Date: Wed, 22 Apr 2026 22:52:02 +0800
Subject: [PATCH 35/41] add t2i/it2i thinking
---
lightllm/server/api_openai.py | 37 ++++++++++++++++++++---------------
1 file changed, 21 insertions(+), 16 deletions(-)
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index 55112d7a4b..0f741f6c79 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -308,7 +308,6 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
finish_reason = finish_reason_dict[sub_req_id]
text = "".join(final_output_dict[sub_req_id])
- full_text = text
# Handle reasoning content
reasoning_text = None
@@ -333,14 +332,12 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
tool_calls = None
tool_choice = request.tool_choice
tools = request.tools
- if tool_choice != "none" and any([i in full_text for i in TOOLS_TAG_LIST]):
- if finish_reason == "stop":
- finish_reason = "tool_calls"
+ if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
try:
# provide a default value for tool_call_parser
tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3"
parser = FunctionCallParser(tools, tool_parser)
- full_normal_text, call_info_list = parser.parse_non_stream(full_text)
+ text, call_info_list = parser.parse_non_stream(text)
tool_calls = []
history_tool_calls_cnt = _get_history_tool_calls_cnt(request)
for call_info in call_info_list:
@@ -358,8 +355,8 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
HTTPStatus.BAD_REQUEST,
"Failed to parse fc related info to json format!",
)
- if finish_reason == "tool_calls":
- text = ""
+ if tool_calls and finish_reason == "stop":
+ finish_reason = "tool_calls"
chat_message = ChatMessage(
role="assistant",
content=text if text else "",
@@ -610,13 +607,17 @@ def _normalize_image_b64_for_multimodal(image: Union[str, bytes]) -> str:
return image
-def _apply_image_generation_stop(chat_request: ChatCompletionRequest, image_start_tag: str) -> None:
+def _apply_image_generation_stop(
+ chat_request: ChatCompletionRequest, image_start_tag: str, image_only: bool = False
+) -> None:
stop = chat_request.stop or []
if isinstance(stop, str):
stop = [stop]
stop = list(stop)
if image_start_tag not in stop:
stop.append(image_start_tag)
+ if chat_request.chat_template_kwargs.get("enable_thinking", False) and image_only:
+ stop.append("") # TODO: from model config
chat_request.stop = stop
@@ -683,8 +684,9 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
image_start_tag = g_objs.httpserver_manager.tokenizer.image_start_tag
image_tag = g_objs.httpserver_manager.tokenizer.image_tag
+ image_only = request.modalities == ["image"]
- _apply_image_generation_stop(chat_request, image_start_tag)
+ _apply_image_generation_stop(chat_request, image_start_tag, image_only=image_only)
created_time = int(time.time())
@@ -694,7 +696,10 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
x2i_params = X2IParams()
x2i_params.init_from_image_config(request.image_config)
- if request.modalities == ["image"]:
+ enable_thinking = request.chat_template_kwargs.get("enable_thinking", False)
+ print(f"x2i_params: {x2i_params} {image_only} {enable_thinking}", flush=True)
+
+ if image_only and not enable_thinking:
return await _chat_completion_image_only(request, raw_request, prompt, multimodal_params, x2i_params)
if not request.stream:
@@ -706,7 +711,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
completion_tokens = 0
finish_reason: Optional[str] = "stop"
group_request_id = None
- max_image_gen_num = 15 # TODO: make this configurable
+ max_image_gen_num = 15 if not image_only else 1 # TODO: make this configurable
while max_image_gen_num > 0:
max_image_gen_num -= 1
@@ -730,7 +735,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
full_text += output_chunk
- if need_call_x2i:
+ if need_call_x2i or image_only:
prompt += output_chunk
images = await g_objs.httpserver_manager.generate_image(
prompt, x2i_params, multimodal_params.clone(), request=raw_request, input_image_num=input_image_num
@@ -741,7 +746,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
response_images.extend(_message_contents_from_raw_images(images, request.image_config.image_type))
for image in images:
prompt += image_tag
- full_text += image_tag
+ full_text += image_tag if not image_only else ""
multimodal_params.add_image({"type": "base64", "data": _normalize_image_b64_for_multimodal(image)})
else:
break
@@ -797,7 +802,7 @@ async def stream_result() -> AsyncGenerator[bytes, None]:
completion_tokens = 0
finish_reason = None
group_request_id = None
- max_image_gen_num = 15 # TODO: make this configurable
+ max_image_gen_num = 15 if not image_only else 1 # TODO: make this configurable
while max_image_gen_num > 0:
max_image_gen_num -= 1
@@ -857,11 +862,11 @@ async def stream_result() -> AsyncGenerator[bytes, None]:
)
yield ("data: " + json.dumps(stream_resp.model_dump(), ensure_ascii=False) + "\n\n").encode("utf-8")
- if need_call_x2i:
+ if need_call_x2i or image_only:
prompt += output_chunk
images = await g_objs.httpserver_manager.generate_image(
- prompt, x2i_params, multimodal_params.clone(), request=raw_request
+ prompt, x2i_params, multimodal_params.clone(), request=raw_request, input_image_num=input_image_num
)
if images is None or len(images) == 0:
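The non-stream tool-call change in this patch is easier to read as a small flow: parse only when a tool tag is present, and flip finish_reason to "tool_calls" only after parsing actually produced calls, so the reply text is no longer blanked out. A hedged sketch; `parse` stands in for FunctionCallParser.parse_non_stream, and the tag list and tool_choice handling are simplified:

def finalize(text, finish_reason, parse, tags=("<tool_call>",)):
    tool_calls = None
    if any(tag in text for tag in tags):
        text, tool_calls = parse(text)
    if tool_calls and finish_reason == "stop":
        finish_reason = "tool_calls"
    return text, tool_calls, finish_reason

# A reply without tool tags passes through untouched.
assert finalize("hello", "stop", parse=None) == ("hello", None, "stop")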
From 3e841651062f273f9c7c8b682b2d6c0e47d5383b Mon Sep 17 00:00:00 2001
From: Charles2530 <2569337619@qq.com>
Date: Wed, 22 Apr 2026 22:52:27 +0800
Subject: [PATCH 36/41] fix build prompt for tool call
---
lightllm/server/build_prompt.py | 34 ++++++++++++++++++++++-----------
1 file changed, 23 insertions(+), 11 deletions(-)
diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py
index a33e35e815..21c1cde678 100644
--- a/lightllm/server/build_prompt.py
+++ b/lightllm/server/build_prompt.py
@@ -45,10 +45,30 @@ def init_tokenizer(args):
return
+def _normalize_tool_call_arguments(messages: list) -> None:
+ # Convert tool_calls function.arguments from JSON string to dict for Jinja template compatibility
+ # Qwen35's chat template expects arguments to be a dict (uses |items filter)
+ # but OpenAI format sends arguments as a JSON string
+ for msg in messages:
+ tool_calls = msg.get("tool_calls")
+ if tool_calls and isinstance(tool_calls, list):
+ for tool_call in tool_calls:
+ func = tool_call.get("function")
+ if func and isinstance(func, dict):
+ args = func.get("arguments")
+ if isinstance(args, str) and args:
+ try:
+ func["arguments"] = json.loads(args)
+ except (json.JSONDecodeError, TypeError):
+ pass
+
+
async def build_prompt(request, tools) -> str:
global tokenizer
# convert the pydantic objects to dicts; otherwise Jinja cannot evaluate them when the template is built from tokenizer_config.json
messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages]
+ _normalize_tool_call_arguments(messages)
+
kwargs = {"conversation": messages}
if request.character_settings:
kwargs["character_settings"] = request.character_settings
@@ -60,15 +80,7 @@ async def build_prompt(request, tools) -> str:
try:
input_str = tokenizer.apply_chat_template(**kwargs, tokenize=False, add_generation_prompt=True, tools=tools)
- except:
- # This except branch will be triggered when the chosen model
- # has a different tools input format that is not compatiable
- # with openAI's apply_chat_template tool_call format, like Mistral.
- tools = [t if "function" in t else {"function": t} for t in tools]
- input_str = tokenizer.apply_chat_template(
- **kwargs,
- tokenize=True,
- add_generation_prompt=True,
- tools=tools,
- )
+ except BaseException as e:
+ logger.error(f"Failed to build prompt: {e}")
+ raise e
return input_str
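To make the intent of _normalize_tool_call_arguments concrete, here is a hedged, self-contained example of the transformation it performs (the message content is made up; only the arguments field matters):

import json

messages = [{
    "role": "assistant",
    "tool_calls": [{
        "id": "call_0",
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'},
    }],
}]

for msg in messages:
    for tool_call in msg.get("tool_calls") or []:
        func = tool_call.get("function")
        if isinstance(func, dict) and isinstance(func.get("arguments"), str) and func["arguments"]:
            try:
                func["arguments"] = json.loads(func["arguments"])  # str -> dict for the Jinja |items filter
            except (json.JSONDecodeError, TypeError):
                pass

assert messages[0]["tool_calls"][0]["function"]["arguments"] == {"city": "Beijing"}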
From b2f456510d8b0e3ace921b428b37cfe6c1d0098e Mon Sep 17 00:00:00 2001
From: Charles2530 <2569337619@qq.com>
Date: Wed, 22 Apr 2026 23:00:42 +0800
Subject: [PATCH 37/41] dynamic_resolution, height/width, seed, image_size
---
lightllm/server/api_models.py | 93 +++++++++++--------
lightllm/server/core/objs/x2i_params.py | 11 ++-
lightllm/server/function_call_parser.py | 5 +-
lightllm/server/httpserver/manager.py | 3 +-
.../server/x2i_server/lightx2v/adapter.py | 17 ++--
lightllm/server/x2i_server/manager.py | 2 +-
6 files changed, 78 insertions(+), 53 deletions(-)
diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py
index c2fce7dd95..d6ceebfea8 100644
--- a/lightllm/server/api_models.py
+++ b/lightllm/server/api_models.py
@@ -121,7 +121,7 @@ class CompletionRequest(BaseModel):
prompt: Union[str, List[str], List[int], List[List[int]]]
suffix: Optional[str] = None
max_tokens: Optional[int] = Field(
- default=16384, deprecated="max_tokens is deprecated, please use max_completion_tokens instead"
+ default=256000, deprecated="max_tokens is deprecated, please use max_completion_tokens instead"
)
max_completion_tokens: Optional[int] = None
temperature: Optional[float] = 1.0
@@ -197,7 +197,7 @@ class ChatCompletionRequest(BaseModel):
stream_options: Optional[StreamOptions] = None
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(
- default=16384, deprecated="max_tokens is deprecated, please use max_completion_tokens instead"
+ default=256000, deprecated="max_tokens is deprecated, please use max_completion_tokens instead"
)
max_completion_tokens: Optional[int] = None
presence_penalty: Optional[float] = 0.0
@@ -389,7 +389,7 @@ def ensure_id_is_str(cls, v):
"21:9",
]
-ImageSize: TypeAlias = Literal["0.5K", "1K", "2K", "4K"]
+ImageSize: TypeAlias = Literal["0.5K", "1K", "1.5K", "2K", "4K"]
Modality: TypeAlias = Literal["text", "image", "audio"]
@@ -398,8 +398,10 @@ def ensure_id_is_str(cls, v):
class ImageConfig(BaseModel):
aspect_ratio: AspectRatio = "1:1"
- image_size: ImageSize = "1K"
+ image_size: ImageSize = "1.5K"
image_type: ImageType = "jpeg"
+ height: Optional[int] = -1
+ width: Optional[int] = -1
# X2I / diffusion sampling (optional; server defaults apply when omitted)
steps: Optional[int] = None
guidance_scale: Optional[float] = None
@@ -407,41 +409,50 @@ class ImageConfig(BaseModel):
seed: Optional[int] = None
num_images: Optional[int] = None
cfg_norm: Optional[Literal["none", "cfg_zero_star", "global", "text_channel", "channel"]] = None
-
- # Mapping to actual resolutions (base resolution for 1K)
+ dynamic_resolution: Optional[bool] = True
_aspect_ratio_to_resolution: ClassVar[dict] = {
- "1:1": (1024, 1024),
- "2:3": (832, 1248),
- "3:2": (1248, 832),
- "3:4": (864, 1184),
- "4:3": (1184, 864),
- "4:5": (896, 1152),
- "5:4": (1152, 896),
- "9:16": (768, 1344),
- "16:9": (1920, 1080),
- "21:9": (1536, 672),
- }
-
- _size_multiplier: ClassVar[dict] = {
- "0.5K": 0.5,
- "1K": 1.0,
- "2K": 2.0,
- "4K": 4.0,
+ "1:1": {"1K": (1024, 1024), "1.5K": (1536, 1536), "2K": (2048, 2048)},
+ "16:9": {"1.5K": (2048, 1152), "2K": (2720, 1536)},
+ "9:16": {"1.5K": (1152, 2048), "2K": (1536, 2720)},
+ "3:2": {"1.5K": (1888, 1248), "2K": (2496, 1664)},
+ "2:3": {"1.5K": (1248, 1888), "2K": (1664, 2496)},
+ "4:3": {"1.5K": (1760, 1312), "2K": (2368, 1760)},
+ "3:4": {"1.5K": (1312, 1760), "2K": (1760, 2368)},
+ "1:2": {"1.5K": (1088, 2144), "2K": (1440, 2880)},
+ "2:1": {"1.5K": (2144, 1088), "2K": (2880, 1440)},
+ "1:3": {"1.5K": (864, 2592), "2K": (1152, 3456)},
+ "3:1": {"1.5K": (2592, 864), "2K": (3456, 1152)},
}
+ _size_set: ClassVar[set[str]] = {"1.5K", "2K"}
- @field_validator("aspect_ratio")
+ @field_validator("image_size", mode="before")
@classmethod
- def validate_aspect_ratio(cls, v):
- if v not in cls._aspect_ratio_to_resolution:
- raise ValueError(f"Unsupported aspect ratio: {v}")
+ def normalize_image_size(cls, v):
+ if isinstance(v, str):
+ return v.strip().upper()
return v
- @field_validator("image_size")
- @classmethod
- def validate_image_size(cls, v):
- if v not in cls._size_multiplier:
- raise ValueError(f"Unsupported image size: {v}")
- return v
+ @model_validator(mode="after")
+ def validate_resolution_config(self):
+ has_custom_height = self.height is not None and self.height > 0
+ has_custom_width = self.width is not None and self.width > 0
+ has_any_custom = (self.height is not None and self.height != -1) or (
+ self.width is not None and self.width != -1
+ )
+
+ # If a custom resolution is provided, both height and width must be given as positive integers.
+ if has_any_custom:
+ if not has_custom_height or not has_custom_width:
+ raise ValueError("height and width must both be provided as positive integers")
+ self.dynamic_resolution = False
+ return self
+
+ # Otherwise, validate ratio and logical image size.
+ if self.aspect_ratio not in self._aspect_ratio_to_resolution:
+ raise ValueError(f"Unsupported aspect ratio: {self.aspect_ratio}")
+ if self.image_size not in self._size_set:
+ raise ValueError(f"Unsupported image size: {self.image_size}")
+ return self
@field_validator("image_type")
@classmethod
@@ -452,16 +463,16 @@ def validate_image_type(cls, v):
def get_resolution(self):
"""Return scaled resolution (width, height)"""
- base = self._aspect_ratio_to_resolution[self.aspect_ratio]
- if base is None:
- return None # extended ratios don't have fixed base
-
- scale = self._size_multiplier[self.image_size]
- w, h = base
- w, h = int(w * scale), int(h * scale)
from lightllm.models.neo_chat_moe.vision_process import smart_resize
- h, w = smart_resize(h, w, factor=32, min_pixels=512 * 512, max_pixels=2048 * 2048)
+ print(f"self.height: {self.height}, self.width: {self.width}", flush=True)
+ if self.height > -1 and self.width > -1:
+ w, h = self.width, self.height
+ else:
+ base = self._aspect_ratio_to_resolution[self.aspect_ratio][self.image_size]
+ w, h = base
+
+ h, w = smart_resize(h, w, factor=32, min_pixels=1024 * 1024, max_pixels=2048 * 2048)
return w, h
diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py
index 281e5f66d2..cac7b9abdd 100644
--- a/lightllm/server/core/objs/x2i_params.py
+++ b/lightllm/server/core/objs/x2i_params.py
@@ -54,9 +54,10 @@ class X2IParams(ctypes.Structure):
_image_guidance_scale: float = 1.0
_seed: int = 42
_num_images: int = 1
- _cfg_norm: CfgNormType = CfgNormType.NONE
- _cfg_interval: float = (-1, 2)
+ _cfg_norm: CfgNormType = CfgNormType.GLOBAL
+ _cfg_interval: float = (0, 1)
_timestep_shift: float = 3.0
+ _dynamic_resolution: bool = True
def init(self, **kwargs):
def _get(key, default):
@@ -76,9 +77,11 @@ def _get(key, default):
self.past_kvcache = PastKVCachePageList()
self.past_kvcache_text = PastKVCachePageList()
self.past_kvcache_img = PastKVCachePageList()
+ self.dynamic_resolution = _get("dynamic_resolution", X2IParams._dynamic_resolution)
self.total_prompt_tokens = 0
self.request_id = 0
self.has_updated_hw = False
+ self.first_image = True
def init_from_image_config(self, image_config: Any) -> None:
"""从 HTTP `image_config`(api_models.ImageConfig)填充,与 `init(**kwargs)` 共用默认值逻辑。"""
@@ -98,6 +101,8 @@ def init_from_image_config(self, image_config: Any) -> None:
kwargs["seed"] = image_config.seed
if image_config.num_images is not None:
kwargs["num_images"] = image_config.num_images
+ if image_config.dynamic_resolution is not None:
+ kwargs["dynamic_resolution"] = image_config.dynamic_resolution
if image_config.cfg_norm is not None:
for e in CfgNormType:
if e.as_str() == image_config.cfg_norm:
@@ -106,6 +111,8 @@ def init_from_image_config(self, image_config: Any) -> None:
self.init(**kwargs)
def update_hw(self, width: int, height: int):
+ if not self.dynamic_resolution:
+ return
if self.has_updated_hw:
return
from lightllm.models.neo_chat_moe.vision_process import smart_resize
diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py
index 3a8fddf744..e9700289c0 100644
--- a/lightllm/server/function_call_parser.py
+++ b/lightllm/server/function_call_parser.py
@@ -11,7 +11,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import ast
import json
+import os
import orjson
import logging
import re
@@ -28,6 +30,7 @@
from .api_models import Tool
logger = logging.getLogger(__name__)
+ENABLE_TOOL_NAME_CHECK = os.getenv("LIGHTLLM_ENABLE_TOOL_NAME_CHECK", "False").upper() in ["ON", "TRUE", "1"]
TOOLS_TAG_LIST = [
"<|plugin|>",
@@ -155,7 +158,7 @@ def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
results = []
for act in action:
name = act.get("name")
- if name and name in tool_indices:
+ if name and (not ENABLE_TOOL_NAME_CHECK or name in tool_indices):
results.append(
ToolCallItem(
tool_index=-1, # Caller should update this based on the actual tools array called
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 99ed011cb1..05db8c335e 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -487,7 +487,7 @@ async def generation_wrapper(prompt, sample, multimodal, request):
else:
# call t2i
prompt_condition, prompt_uncondition = self.tokenizer.get_query_for_t2i(prompt, input_image_num)
- # logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}")
+ logger.info(f"generate image with: {prompt_condition}, and {prompt_uncondition}")
(con_gen, uncon_gen) = await asyncio.gather(
*[
generation_wrapper(prompt_condition, sample_params, multimodal_params, request),
@@ -514,6 +514,7 @@ async def generation_wrapper(prompt, sample, multimodal, request):
assert req_status.response is not None
self.req_id_to_x2i_reqs.pop(x2i_req_id, None)
+ generation_params.first_image = False
return req_status.response.images
diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py
index b738b00d7d..95a795f617 100644
--- a/lightllm/server/x2i_server/lightx2v/adapter.py
+++ b/lightllm/server/x2i_server/lightx2v/adapter.py
@@ -55,9 +55,7 @@ def _init_pipeline(self):
support_tasks=["t2i", "i2i"],
)
self.pipe.create_generator(config_json=self.args.x2v_gen_model_config)
- self.pipe.modify_config({
- "load_kv_cache_in_pipeline_for_debug": False,
- "save_result_for_debug": False})
+ self.pipe.modify_config({"load_kv_cache_in_pipeline_for_debug": False, "save_result_for_debug": False})
async def run(self):
while True:
@@ -95,13 +93,16 @@ async def _process(self, param: X2IParams):
timestep_shift=param.timestep_shift,
)
past_kv_cache = self.past_kv_cache_client.get_kv_cache_for_x2i(
- param.past_kvcache.get_all(), param.past_kvcache.token_len)
+ param.past_kvcache.get_all(), param.past_kvcache.token_len
+ )
past_kv_cache_text = self.past_kv_cache_client.get_kv_cache_for_x2i(
- param.past_kvcache_text.get_all(), param.past_kvcache_text.token_len)
+ param.past_kvcache_text.get_all(), param.past_kvcache_text.token_len
+ )
past_kv_cache_img = None
if not is_t2i:
past_kv_cache_img = self.past_kv_cache_client.get_kv_cache_for_x2i(
- param.past_kvcache_img.get_all(), param.past_kvcache_img.token_len)
+ param.past_kvcache_img.get_all(), param.past_kvcache_img.token_len
+ )
dist.barrier() # ensure all workers have got the kv cache before generation starts
@@ -122,8 +123,10 @@ async def _process(self, param: X2IParams):
past_kv_cache_text,
past_kv_cache_img,
)
+ seed = param.seed if param.first_image else None
+ logger.info(f"seed: {seed} {param.seed} first_image: {param.first_image}")
image = self.pipe.generate(
- seed=param.seed + param.past_kvcache.img_len,
+ seed=seed,
save_result_path="",
target_shape=[param.height, param.width],
)
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index aef9b0ab88..c48f283e79 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -108,7 +108,7 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams
for i in range(param.num_images):
self.gen_pipe.runner.set_kvcache(past_kv_cache, past_kv_cache_text)
image = self.gen_pipe.generate(
- seed=param.seed + param.past_kvcache.img_len + i,
+ seed=param.seed if param.first_image else None,
save_result_path="", # 返回base64,不需要指定路径了
target_shape=[param.height, param.width], # Height, Width
)
From c99e4b35e0436f6650431e671674765c4cea6d99 Mon Sep 17 00:00:00 2001
From: Weichao Luo
Date: Thu, 23 Apr 2026 08:14:24 +0000
Subject: [PATCH 38/41] fix lint.
---
.../triton_kernel/kv_cache_offload.py | 32 +-
.../layer_infer/transformer_layer_infer.py | 1 -
lightllm/models/neo_chat/model.py | 12 +-
lightllm/models/neo_chat_moe/infer_struct.py | 1 -
lightllm/models/neo_chat_moe/model.py | 12 +-
lightllm/server/api_cli.py | 3 +-
lightllm/server/api_lightllm.py | 2 +-
lightllm/server/api_openai.py | 6 +-
lightllm/server/api_start.py | 1 +
lightllm/server/core/objs/start_args_type.py | 2 +-
.../core/objs/token_chunck_hash_list.py | 6 +-
.../model_infer/mode_backend/base_backend.py | 4 +-
.../model_infer/mode_backend/past_kv_cache.py | 27 +-
.../naive/configuration_neo_chat.py | 26 +-
.../x2i_server/naive/configuration_neo_vit.py | 40 +-
.../x2i_server/naive/modeling_fm_modules.py | 136 +++---
.../x2i_server/naive/modeling_neo_chat.py | 403 +++++++++++-------
.../x2i_server/naive/modeling_neo_vit.py | 71 +--
.../test_context_attention_fwd_neo.py | 4 +-
.../att/prefill_att/test_fa3_neo.py | 18 +-
.../kv_trans_kernel/test_kv_trans_from_gpu.py | 65 ++-
21 files changed, 481 insertions(+), 391 deletions(-)
diff --git a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py
index 01d32e908c..0cff35864c 100644
--- a/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py
+++ b/lightllm/common/basemodel/triton_kernel/kv_cache_offload.py
@@ -706,7 +706,6 @@ def load_cpu_kv_to_gpu(
return
-
@triton.jit
def _offload_gpu_kv_to_cpu_for_x2i(
token_indexes_ptr,
@@ -800,8 +799,12 @@ def _offload_gpu_kv_to_cpu_for_x2i(
+ cpu_k_head_index * cpu_scale_stride3
+ head_dim_range[None, :]
)
- tl.store(cpu_scale_ptr, gpu_scale_data, mask=token_scale_mask, cache_modifier=".wt",)
-
+ tl.store(
+ cpu_scale_ptr,
+ gpu_scale_data,
+ mask=token_scale_mask,
+ cache_modifier=".wt",
+ )
for v_head_index in range(gpu_v_head_num):
gpu_v_head_index = v_head_index + gpu_v_start_head_index
@@ -842,8 +845,12 @@ def _offload_gpu_kv_to_cpu_for_x2i(
+ cpu_v_head_index * cpu_scale_stride3
+ head_dim_range[None, :]
)
- tl.store(cpu_scale_ptr, gpu_scale_data, mask=token_scale_mask, cache_modifier=".wt",)
-
+ tl.store(
+ cpu_scale_ptr,
+ gpu_scale_data,
+ mask=token_scale_mask,
+ cache_modifier=".wt",
+ )
@torch.no_grad()
@@ -955,7 +962,7 @@ def offload_gpu_kv_to_cpu_for_x2i(
assert token_block_size == triton.next_power_of_2(token_block_size)
page_num = page_indexes.shape[0]
- grid = (grid_num, )
+ grid = (grid_num,)
num_warps = 4
num_stages = 1
HAS_SCALE = gpu_kv_cache_scale is not None and cpu_kv_cache_scale is not None
@@ -968,14 +975,13 @@ def offload_gpu_kv_to_cpu_for_x2i(
gpu_scale_stride = [0 for _ in range(5)]
cpu_scale_stride = [0 for _ in range(5)]
-
_offload_gpu_kv_to_cpu_for_x2i[grid](
- token_indexes_ptr = token_indexes,
- gpu_kv_cache_ptr = gpu_kv_cache,
- gpu_stride0 = gpu_kv_cache.stride(0),
- gpu_stride1 = gpu_kv_cache.stride(1),
- gpu_stride2 = gpu_kv_cache.stride(2),
- gpu_kv_cache_scale_ptr = gpu_kv_cache_scale,
+ token_indexes_ptr=token_indexes,
+ gpu_kv_cache_ptr=gpu_kv_cache,
+ gpu_stride0=gpu_kv_cache.stride(0),
+ gpu_stride1=gpu_kv_cache.stride(1),
+ gpu_stride2=gpu_kv_cache.stride(2),
+ gpu_kv_cache_scale_ptr=gpu_kv_cache_scale,
gpu_scale_stride0=gpu_scale_stride[0],
gpu_scale_stride1=gpu_scale_stride[1],
gpu_scale_stride2=gpu_scale_stride[2],
diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
index a820ddaaf6..c5c4f3c343 100644
--- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -97,7 +97,6 @@ def _get_qkv(self, input, infer_state: NeoChatInferStateInfo, layer_weight: NeoC
def _context_attention_kernel(
self, q, kv, infer_state: NeoChatInferStateInfo, layer_weight, out=None
) -> torch.Tensor:
-
if _USE_TRITON_PREFILL:
o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out
diff --git a/lightllm/models/neo_chat/model.py b/lightllm/models/neo_chat/model.py
index e17ad31fa1..80621034ae 100644
--- a/lightllm/models/neo_chat/model.py
+++ b/lightllm/models/neo_chat/model.py
@@ -46,12 +46,12 @@ def _init_inferstate_cls(self):
pass
def _init_att_backend(self):
- self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(
- index=0, priority_list=["fa3"]
- )(model=self)
- self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(
- index=0, priority_list=["fa3"]
- )(model=self)
+ self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0, priority_list=["fa3"])(
+ model=self
+ )
+ self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(index=0, priority_list=["fa3"])(
+ model=self
+ )
def _init_config(self):
with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py
index b8e6483def..1693bcb964 100644
--- a/lightllm/models/neo_chat_moe/infer_struct.py
+++ b/lightllm/models/neo_chat_moe/infer_struct.py
@@ -20,7 +20,6 @@ def __init__(self):
def init_some_extra_state(self, model: LlamaTpPartModel):
LlamaInferStateInfo.init_some_extra_state(self, model)
if self.is_prefill:
- bsz = self.b_q_seq_len.shape[0]
self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda(
non_blocking=True
)
diff --git a/lightllm/models/neo_chat_moe/model.py b/lightllm/models/neo_chat_moe/model.py
index 7ae3cd27c0..fb0cb988d2 100644
--- a/lightllm/models/neo_chat_moe/model.py
+++ b/lightllm/models/neo_chat_moe/model.py
@@ -188,12 +188,12 @@ def _init_inferstate_cls(self):
pass
def _init_att_backend(self):
- self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(
- index=0, priority_list=["fa3"]
- )(model=self)
- self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(
- index=0, priority_list=["fa3"]
- )(model=self)
+ self.prefill_att_backend: BaseAttBackend = get_prefill_att_backend_class(index=0, priority_list=["fa3"])(
+ model=self
+ )
+ self.decode_att_backend: BaseAttBackend = get_decode_att_backend_class(index=0, priority_list=["fa3"])(
+ model=self
+ )
def _init_config(self):
with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index bc811ff17b..e2aa4f79cf 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -352,7 +352,8 @@ def make_argument_parser() -> argparse.ArgumentParser:
type=str,
choices=["colocate", "separate"],
default="colocate",
- help="Deployment mode for the x2i server. 'colocate' means the x2i server will run on the same gpus as the llm server, ",
+ help="Deployment mode for the x2i server. 'colocate' means the x2i server will "
+ "run on the same gpus as the llm server, ",
)
parser.add_argument(
"--x2i_use_naive_impl",
diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py
index b20670e47c..7963423d8f 100644
--- a/lightllm/server/api_lightllm.py
+++ b/lightllm/server/api_lightllm.py
@@ -165,4 +165,4 @@ async def lightllm_generate_image(request: Request, httpserver_manager: HttpServ
results = await httpserver_manager.generate_image(prompt, generation_params, multimodal_params, request=request)
- return Response(content=json.dumps({"images": results}, ensure_ascii=False).encode("utf-8"))
\ No newline at end of file
+ return Response(content=json.dumps({"images": results}, ensure_ascii=False).encode("utf-8"))
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index 376cb6c397..d137bba162 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -185,9 +185,7 @@ def _get_images_and_audios(request: ChatCompletionRequest):
# Local file path with file:// prefix
file_path = img[7:] # Remove "file://" prefix
with open(file_path, "rb") as f:
- images.append(
- {"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")}
- )
+ images.append({"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")})
else:
raise ValueError(
"Unrecognized image input. Supports local path, http url, base64, and PIL.Image."
@@ -222,7 +220,7 @@ def _get_tools(request: ChatCompletionRequest):
tools = [item.function.model_dump() for item in request.tools]
return tools
-
+
def _split_tool_argument_delta(arguments: Optional[str]) -> List[str]:
"""Split a complete JSON argument string into OpenAI-style deltas."""
if not arguments:
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 3cdc84c68c..5be30526a5 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -413,6 +413,7 @@ def normal_or_p_d_start(args):
if args.enable_multimodal_x2i:
from .x2i_server.manager import start_x2i_process, setup_devices
+
origin_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
setup_devices(args)
process_manager.start_submodule_processes(
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 73b8b0060e..84253ab29e 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -187,7 +187,7 @@ class StartArgs:
metric_port: int = field(default=None)
multinode_httpmanager_port: int = field(default=12345)
multi_level_kv_cache_port: int = field(default=None)
-
+
# multi_modal_x2i
enable_multimodal_x2i: bool = field(default=False)
x2i_port: int = field(default=None)
diff --git a/lightllm/server/core/objs/token_chunck_hash_list.py b/lightllm/server/core/objs/token_chunck_hash_list.py
index de43cc4cc6..6927cc3488 100644
--- a/lightllm/server/core/objs/token_chunck_hash_list.py
+++ b/lightllm/server/core/objs/token_chunck_hash_list.py
@@ -91,10 +91,10 @@ def get_all(self):
class PastKVCachePageList(CpuCachePageList):
_pack_ = 4
- _fields_ = CpuCachePageList._fields_ +[
+ _fields_ = CpuCachePageList._fields_ + [
("token_len", ctypes.c_int), # 对应的token数量
("img_tokens", ctypes.c_int),
- ("img_len", ctypes.c_int)
+ ("img_len", ctypes.c_int),
]
def __init__(self, token_len: int = 0):
@@ -107,4 +107,4 @@ def get_compressed_len(self):
return self.token_len - self.img_tokens + self.img_len
def __repr__(self):
- return f"(token_len={self.token_len}, img_tokens={self.img_tokens}, img_len={self.img_len})"
\ No newline at end of file
+ return f"(token_len={self.token_len}, img_tokens={self.img_tokens}, img_len={self.img_len})"
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index df8dfe3316..7a4d851603 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -48,6 +48,7 @@
from .multi_level_kv_cache import MultiLevelKvCacheModule
from .past_kv_cache import PastKVCacheModule
+
class ModeBackend:
def __init__(self) -> None:
self.shm_req_manager = ShmReqManager()
@@ -658,7 +659,8 @@ def _get_classed_reqs(
if self.args.enable_multimodal_x2i:
true_finished_reqs = self.past_kv_cache_module.offload_finished_reqs_to_past_kv_cache(
- finished_reqs=true_finished_reqs)
+ finished_reqs=true_finished_reqs
+ )
g_infer_context.filter_reqs(finished_reqs=true_finished_reqs)
g_infer_context.pause_reqs(wait_pause_reqs, is_master_in_dp=self.is_master_in_dp)
diff --git a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
index e6457ded9a..79f9e62a2f 100644
--- a/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
+++ b/lightllm/server/router/model_infer/mode_backend/past_kv_cache.py
@@ -22,9 +22,11 @@ class TransTask:
sync_event: torch.cuda.Event
buffer_index: int
+
class PastKVCacheModule(object):
def __init__(self, backend):
from .base_backend import ModeBackend
+
self.backend: ModeBackend = backend
self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=False)
self.page_index_buffer = torch.empty((1024 * LIGHTLLM_TOKEN_HASH_LIST_SIZE,), dtype=torch.int32, device="cuda")
@@ -32,7 +34,6 @@ def __init__(self, backend):
self.past_kv_cache_task: Deque[TransTask] = deque()
self.sync_task_status_group = create_new_group_for_current_dp("gloo")
-
@lru_cache()
def need_sync_compute_stream(self) -> bool:
"""
@@ -71,8 +72,9 @@ def offload_finished_reqs_to_past_kv_cache(self, finished_reqs: List[InferReq])
if req.past_kv_cache_task_status.is_running():
continue
- assert req.past_kv_cache_task_status.is_not_started(), \
- f"req {req.req_id} has invalid past kv cache task status {req.past_kv_cache_task_status}"
+ assert (
+ req.past_kv_cache_task_status.is_not_started()
+ ), f"req {req.req_id} has invalid past kv cache task status {req.past_kv_cache_task_status}"
if self.need_sync_compute_stream():
g_infer_context.get_overlap_stream().synchronize()
@@ -95,7 +97,9 @@ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]:
start = free_index * LIGHTLLM_TOKEN_HASH_LIST_SIZE
end = start + req.shm_req.past_kv_cache_page_indexes.size
- page_indexes = torch.tensor(req.shm_req.past_kv_cache_page_indexes.get_all(), dtype=torch.int32, device='cpu', pin_memory=True)
+ page_indexes = torch.tensor(
+ req.shm_req.past_kv_cache_page_indexes.get_all(), dtype=torch.int32, device="cpu", pin_memory=True
+ )
num_tokens = req.shm_req.input_len
assert req.cur_kv_len >= num_tokens
@@ -104,13 +108,12 @@ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]:
cuda_page_indexes = self.page_index_buffer[start:end]
cuda_page_indexes.copy_(page_indexes)
- token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0: num_tokens]
+ token_indexes = self.backend.model.req_manager.req_to_token_indexs[req.req_idx, 0:num_tokens]
mem_manager = self.backend.model.mem_manager
-
if hasattr(mem_manager, "scale_buffer") and mem_manager.scale_buffer is not None:
cpu_cache_meta = self.past_kv_cache_client.kv_cache_tensor_meta
- cpu_kv_cache = self.past_kv_cache_client.cpu_kv_cache_tensor[:, :, :, :, 0:cpu_cache_meta.head_dim]
+ cpu_kv_cache = self.past_kv_cache_client.cpu_kv_cache_tensor[:, :, :, :, 0 : cpu_cache_meta.head_dim]
cpu_kv_cache_scale = self.past_kv_cache_client.cpu_kv_cache_tensor[
:, :, :, :, cpu_cache_meta.head_dim :
].view(mem_manager.scale_buffer.dtype)
@@ -137,11 +140,7 @@ def _start_kv_cache_offload(self, req: InferReq) -> Optional[TransTask]:
sync_event.wait(g_infer_context.get_overlap_stream())
# sync_event.synchronize()
req.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.RUNNING
- return TransTask(
- req_obj=req,
- sync_event=sync_event,
- buffer_index=free_index
- )
+ return TransTask(req_obj=req, sync_event=sync_event, buffer_index=free_index)
def update_past_kv_cache_task_states(self):
trans_ok_tasks = []
@@ -157,7 +156,7 @@ def update_past_kv_cache_task_states(self):
dist.all_reduce(ok_tasks_num, op=dist.ReduceOp.MIN, group=self.sync_task_status_group)
if ok_tasks_num.item() > 0:
- finished, unfinished = trans_ok_tasks[:ok_tasks_num.item()], trans_ok_tasks[ok_tasks_num.item():]
+ finished, unfinished = trans_ok_tasks[: ok_tasks_num.item()], trans_ok_tasks[ok_tasks_num.item() :]
self.past_kv_cache_task.extendleft(reversed(unfinished))
for task in finished:
task.req_obj.past_kv_cache_task_status = InferReq._CpuCacheTaskStatus.FINISHED
@@ -171,4 +170,4 @@ def update_past_kv_cache_task_states(self):
shm_req.candetoken_out_len = shm_req.shm_cur_output_len
else:
if len(trans_ok_tasks) > 0:
- self.past_kv_cache_task.extendleft(reversed(trans_ok_tasks))
\ No newline at end of file
+ self.past_kv_cache_task.extendleft(reversed(trans_ok_tasks))
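The hunks above stage the request's page indexes in a pinned CPU tensor, copy them into a preallocated GPU index buffer, and record a CUDA event so the offload can later be polled with event.query() instead of blocking the compute stream. A minimal sketch of that staging pattern, with hypothetical names and assuming a CUDA device is present:

import torch

def start_offload(page_indexes_list, gpu_index_buffer, start):
    # Stage indexes in pinned host memory so the host-to-device copy can run asynchronously.
    page_indexes = torch.tensor(page_indexes_list, dtype=torch.int32, device="cpu", pin_memory=True)
    end = start + page_indexes.numel()
    gpu_index_buffer[start:end].copy_(page_indexes, non_blocking=True)
    # Record an event on the current stream; query() returns True once the copy has completed.
    sync_event = torch.cuda.Event()
    sync_event.record()
    return sync_event

if torch.cuda.is_available():
    buffer = torch.empty(1024, dtype=torch.int32, device="cuda")
    event = start_offload([3, 7, 11], buffer, start=0)
    torch.cuda.synchronize()
    assert event.query()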
diff --git a/lightllm/server/x2i_server/naive/configuration_neo_chat.py b/lightllm/server/x2i_server/naive/configuration_neo_chat.py
index 44171917a5..356284d7b0 100644
--- a/lightllm/server/x2i_server/naive/configuration_neo_chat.py
+++ b/lightllm/server/x2i_server/naive/configuration_neo_chat.py
@@ -18,7 +18,7 @@ def __init__(self, rope_theta_hw=10000.0, max_position_embeddings_hw=10000, **kw
class NEOChatConfig(PretrainedConfig):
- model_type = 'neo_chat'
+ model_type = "neo_chat"
is_composition = True
def __init__(
@@ -34,13 +34,13 @@ def __init__(
super().__init__(**kwargs)
if vision_config is None:
- vision_config = {'architectures': ['NEOVisionModel']}
- logger.info('vision_config is None. Initializing the NEOVisionConfig with default values.')
+ vision_config = {"architectures": ["NEOVisionModel"]}
+ logger.info("vision_config is None. Initializing the NEOVisionConfig with default values.")
if llm_config is None:
- llm_config = {'architectures': ['Qwen3ForCausalLM']}
- logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).')
- assert 'architectures' in llm_config, "Should specify architecture in llm_config"
+ llm_config = {"architectures": ["Qwen3ForCausalLM"]}
+ logger.info("llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).")
+ assert "architectures" in llm_config, "Should specify architecture in llm_config"
if isinstance(vision_config, dict):
self.vision_config = NEOVisionConfig(**vision_config)
@@ -66,12 +66,12 @@ def to_dict(self):
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
- output['vision_config'] = self.vision_config.to_dict()
- output['llm_config'] = self.llm_config.to_dict()
- output['model_type'] = self.__class__.model_type
- output['use_backbone_lora'] = self.use_backbone_lora
- output['use_llm_lora'] = self.use_llm_lora
- output['downsample_ratio'] = self.downsample_ratio
- output['template'] = self.template
+ output["vision_config"] = self.vision_config.to_dict()
+ output["llm_config"] = self.llm_config.to_dict()
+ output["model_type"] = self.__class__.model_type
+ output["use_backbone_lora"] = self.use_backbone_lora
+ output["use_llm_lora"] = self.use_llm_lora
+ output["downsample_ratio"] = self.downsample_ratio
+ output["template"] = self.template
return output
diff --git a/lightllm/server/x2i_server/naive/configuration_neo_vit.py b/lightllm/server/x2i_server/naive/configuration_neo_vit.py
index 02837fea41..2e0d09356f 100644
--- a/lightllm/server/x2i_server/naive/configuration_neo_vit.py
+++ b/lightllm/server/x2i_server/naive/configuration_neo_vit.py
@@ -9,26 +9,26 @@
class NEOVisionConfig(PretrainedConfig):
- model_type = 'neo_vision'
+ model_type = "neo_vision"
def __init__(
- self,
- num_channels=3,
- patch_size=16,
- hidden_size=1024,
- llm_hidden_size=2048,
- downsample_ratio=0.5,
- rope_theta_vision=10000.0,
- max_position_embeddings_vision=10000,
- min_pixels=65536,
- max_pixels=4194304,
- **kwargs,
+ self,
+ num_channels=3,
+ patch_size=16,
+ hidden_size=1024,
+ llm_hidden_size=2048,
+ downsample_ratio=0.5,
+ rope_theta_vision=10000.0,
+ max_position_embeddings_vision=10000,
+ min_pixels=65536,
+ max_pixels=4194304,
+ **kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
- self.llm_hidden_size = llm_hidden_size,
- self.downsample_ratio = downsample_ratio,
+ self.llm_hidden_size = (llm_hidden_size,)
+ self.downsample_ratio = (downsample_ratio,)
self.rope_theta_vision = rope_theta_vision
self.max_position_embeddings_vision = max_position_embeddings_vision
self.num_channels = num_channels
@@ -37,16 +37,16 @@ def __init__(
self.max_pixels = max_pixels
@classmethod
- def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
- if 'vision_config' in config_dict:
- config_dict = config_dict['vision_config']
+ if "vision_config" in config_dict:
+ config_dict = config_dict["vision_config"]
- if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
- f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
- return cls.from_dict(config_dict, **kwargs)
\ No newline at end of file
+ return cls.from_dict(config_dict, **kwargs)
diff --git a/lightllm/server/x2i_server/naive/modeling_fm_modules.py b/lightllm/server/x2i_server/naive/modeling_fm_modules.py
index 56654b2580..ff36ddf37a 100644
--- a/lightllm/server/x2i_server/naive/modeling_fm_modules.py
+++ b/lightllm/server/x2i_server/naive/modeling_fm_modules.py
@@ -5,11 +5,14 @@
from functools import lru_cache
from torch.utils.checkpoint import checkpoint
+
+
def modulate(x, shift, scale=None):
if shift is None:
return x * (1 + scale)
return x * (1 + scale) + shift
+
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
@@ -20,6 +23,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
return output * self.weight
+
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
@@ -59,8 +63,8 @@ def forward(self, t):
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
return t_emb
-class ResBlock(nn.Module):
+class ResBlock(nn.Module):
def __init__(self, channels, mlp_ratio=1.0):
super().__init__()
self.channels = channels
@@ -81,6 +85,7 @@ def forward(self, x, y):
h = self.mlp(h)
return x + gate_mlp * h
+
# class FinalLayer(nn.Module):
# def __init__(self, model_channels, out_channels):
@@ -162,8 +167,8 @@ def forward(self, x, y):
# return self.final_layer(x, y)
-class FlowMatchingHead(nn.Module):
+class FlowMatchingHead(nn.Module):
def __init__(self, input_dim, out_dim, dim=1536, layers=12, mlp_ratio=1.0):
super(FlowMatchingHead, self).__init__()
self.net = SimpleMLPAdaLN(input_dim=input_dim, out_dim=out_dim, dim=dim, layers=layers, mlp_ratio=mlp_ratio)
@@ -181,7 +186,7 @@ def forward(self, x, t):
return x
-def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 10000.0, scale=16.0):
+def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=16.0):
# assert H * H == end
# flat_patch_pos = torch.linspace(-1, 1, end) # N = end
x_pos = torch.linspace(0, scale, width)
@@ -189,22 +194,23 @@ def precompute_freqs_cis_2d(dim: int, height: int, width:int, theta: float = 100
y_pos, x_pos = torch.meshgrid(y_pos, x_pos, indexing="ij")
y_pos = y_pos.reshape(-1)
x_pos = x_pos.reshape(-1)
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
- x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
- y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) # Hc/4
+ x_freqs = torch.outer(x_pos, freqs).float() # N Hc/4
+ y_freqs = torch.outer(y_pos, freqs).float() # N Hc/4
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)
- freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
- freqs_cis = freqs_cis.reshape(height*width, -1)
+ freqs_cis = torch.cat([x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1) # N,Hc/4,2
+ freqs_cis = freqs_cis.reshape(height * width, -1)
return freqs_cis
+
class NerfEmbedder(nn.Module):
def __init__(self, in_channels, hidden_size_input, max_freqs):
super().__init__()
self.max_freqs = max_freqs
self.hidden_size_input = hidden_size_input
self.embedder = nn.Sequential(
- nn.Linear(in_channels+max_freqs**2, hidden_size_input, bias=True),
+ nn.Linear(in_channels + max_freqs ** 2, hidden_size_input, bias=True),
)
@lru_cache
@@ -213,7 +219,6 @@ def fetch_pos(self, patch_size, device, dtype):
pos = pos[None, :, :].to(device=device, dtype=dtype)
return pos
-
def forward(self, inputs):
B, P2, C = inputs.shape
patch_size = int(P2 ** 0.5)
@@ -225,6 +230,7 @@ def forward(self, inputs):
inputs = self.embedder(inputs)
return inputs
+
class SimpleMLPAdaLN(nn.Module):
"""
The MLP for Diffusion Loss.
@@ -243,7 +249,7 @@ def __init__(
z_channels,
num_res_blocks,
patch_size,
- grad_checkpointing=False
+ grad_checkpointing=False,
):
super().__init__()
@@ -254,15 +260,17 @@ def __init__(
self.grad_checkpointing = grad_checkpointing
self.patch_size = patch_size
- self.cond_embed = nn.Linear(z_channels, patch_size**2*model_channels)
+ self.cond_embed = nn.Linear(z_channels, patch_size ** 2 * model_channels)
self.input_proj = nn.Linear(in_channels, model_channels)
res_blocks = []
for i in range(num_res_blocks):
- res_blocks.append(ResBlock(
- model_channels,
- ))
+ res_blocks.append(
+ ResBlock(
+ model_channels,
+ )
+ )
self.res_blocks = nn.ModuleList(res_blocks)
self.final_layer = FinalLayer(model_channels, out_channels)
@@ -275,6 +283,7 @@ def _basic_init(module):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
+
self.apply(_basic_init)
# Zero-out adaLN modulation layers
@@ -297,7 +306,7 @@ def forward(self, x, c):
x = self.input_proj(x)
c = self.cond_embed(c)
- y = c.reshape(-1, self.patch_size**2, self.model_channels)
+ y = c.reshape(-1, self.patch_size ** 2, self.model_channels)
for block in self.res_blocks:
x = block(x, y)
@@ -309,6 +318,7 @@ class FinalLayer(nn.Module):
"""
The final layer adopted from DiT.
"""
+
def __init__(self, model_channels, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
@@ -319,6 +329,7 @@ def forward(self, x):
x = self.linear(x)
return x
+
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
@@ -363,7 +374,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
- omega = 1.0 / 10000**omega # (D/2,)
+ omega = 1.0 / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
@@ -374,6 +385,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
+
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
@@ -385,11 +397,11 @@ def interpolate_pos_embed(model_path, pe_key: str = "gen_pos_embed", new_len: in
pos_embed_1d = state_dict[pe_key]
_, ori_len, embed_dim = pos_embed_1d.shape
- ori_size = int(ori_len**0.5)
- new_size = int(new_len**0.5)
+ ori_size = int(ori_len ** 0.5)
+ new_size = int(new_len ** 0.5)
if ori_size != new_size:
- logger.info("Position interpolate from %dx%d to %dx%d" % (ori_size, ori_size, new_size, new_size))
+ # logger.info("Position interpolate from %dx%d to %dx%d" % (ori_size, ori_size, new_size, new_size))
pos_embed_2d = pos_embed_1d.reshape(-1, ori_size, ori_size, embed_dim).permute(0, 3, 1, 2)
pos_embed_2d = torch.nn.functional.interpolate(
pos_embed_2d, size=(new_size, new_size), mode="bicubic", align_corners=False
@@ -399,15 +411,13 @@ def interpolate_pos_embed(model_path, pe_key: str = "gen_pos_embed", new_len: in
torch.save(state_dict, model_path)
+
class PositionEmbedding(nn.Module):
def __init__(self, max_num_patch_per_side, hidden_size):
super().__init__()
self.max_num_patch_per_side = max_num_patch_per_side
self.hidden_size = hidden_size
- self.pos_embed = nn.Parameter(
- torch.zeros(max_num_patch_per_side ** 2, hidden_size),
- requires_grad=False
- )
+ self.pos_embed = nn.Parameter(torch.zeros(max_num_patch_per_side ** 2, hidden_size), requires_grad=False)
self._init_weights()
def _init_weights(self):
@@ -457,37 +467,37 @@ def __init__(self, hidden_dim=4096, out_channels=3):
# self.proj = nn.Linear(hidden_dim, 1024)
# self.act = nn.SiLU()
- self.up_blocks = nn.ModuleList([
- nn.Sequential(
- nn.Upsample(scale_factor=2, mode='nearest'),
- nn.Conv2d(hidden_dim, 512, kernel_size=3, padding=1),
- nn.GroupNorm(32, 512),
- nn.SiLU()
- ),
- nn.Sequential(
- nn.Upsample(scale_factor=2, mode='nearest'),
- nn.Conv2d(512, 256, kernel_size=3, padding=1),
- nn.GroupNorm(32, 256),
- nn.SiLU()
- ),
- nn.Sequential(
- nn.Upsample(scale_factor=2, mode='nearest'),
- nn.Conv2d(256, 64, kernel_size=3, padding=1),
- nn.GroupNorm(32, 64),
- nn.SiLU()
- ),
- nn.Sequential(
- nn.Upsample(scale_factor=2, mode='nearest'),
- nn.Conv2d(64, 32, kernel_size=3, padding=1),
- nn.GroupNorm(16, 32),
- nn.SiLU()
- ),
- nn.Sequential(
- nn.Upsample(scale_factor=2, mode='nearest'),
- nn.Conv2d(32, 16, kernel_size=3, padding=1),
- nn.SiLU()
- )
- ])
+ self.up_blocks = nn.ModuleList(
+ [
+ nn.Sequential(
+ nn.Upsample(scale_factor=2, mode="nearest"),
+ nn.Conv2d(hidden_dim, 512, kernel_size=3, padding=1),
+ nn.GroupNorm(32, 512),
+ nn.SiLU(),
+ ),
+ nn.Sequential(
+ nn.Upsample(scale_factor=2, mode="nearest"),
+ nn.Conv2d(512, 256, kernel_size=3, padding=1),
+ nn.GroupNorm(32, 256),
+ nn.SiLU(),
+ ),
+ nn.Sequential(
+ nn.Upsample(scale_factor=2, mode="nearest"),
+ nn.Conv2d(256, 64, kernel_size=3, padding=1),
+ nn.GroupNorm(32, 64),
+ nn.SiLU(),
+ ),
+ nn.Sequential(
+ nn.Upsample(scale_factor=2, mode="nearest"),
+ nn.Conv2d(64, 32, kernel_size=3, padding=1),
+ nn.GroupNorm(16, 32),
+ nn.SiLU(),
+ ),
+ nn.Sequential(
+ nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.SiLU()
+ ),
+ ]
+ )
self.out_conv = nn.Conv2d(16, out_channels, kernel_size=3, padding=1)
@@ -520,8 +530,8 @@ def __init__(self):
def forward(self, x):
# x shape: [B, 4096, H/32, W/32]
- x = self.ps1(self.act1(self.conv1(x))) # -> [B, 256, H/8, W/8]
- x = self.ps2(self.conv2(x)) # -> [B, 3, H, W]
+ x = self.ps1(self.act1(self.conv1(x))) # -> [B, 256, H/8, W/8]
+ x = self.ps2(self.conv2(x)) # -> [B, 3, H, W]
return x
@@ -544,11 +554,12 @@ def __init__(self):
def forward(self, x):
# x shape: [B, 4096, H/32, W/32]
- x = self.act1(self.conv1(self.ps1((x)))) # -> [B, 256, H/16, W/16]
- x = self.act2(self.conv2(self.ps2((x)))) # -> [B, 256, H/8, W/8]
- x = self.conv3(self.ps3((x))) # -> [B, 3, H, W]
+ x = self.act1(self.conv1(self.ps1((x)))) # -> [B, 256, H/16, W/16]
+ x = self.act2(self.conv2(self.ps2((x)))) # -> [B, 256, H/8, W/8]
+ x = self.conv3(self.ps3((x))) # -> [B, 3, H, W]
return x
+
class PatchDecoder_preps1(nn.Module):
def __init__(self):
super().__init__()
@@ -566,10 +577,11 @@ def __init__(self):
def forward(self, x):
# x shape: [B, 4096, H/32, W/32]
- x = self.act1(self.conv1(self.ps1((x)))) # -> [B, 256, H/16, W/16]
- x = self.ps3(self.conv2(self.ps2((x)))) # -> [B, 256, H/8, W/8]
+ x = self.act1(self.conv1(self.ps1((x)))) # -> [B, 256, H/16, W/16]
+ x = self.ps3(self.conv2(self.ps2((x)))) # -> [B, 256, H/8, W/8]
return x
+
class ConvDecoder(nn.Module):
def __init__(self, input_dim=4096, hidden_dim=1024):
super().__init__()
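The get_1d_sincos_pos_embed_from_grid hunk reformatted above computes the classic fixed sine/cosine table: omega_d = 1 / 10000^(2d/D), followed by the outer product with the positions and a sin/cos concatenation. A self-contained sketch of that computation (same formula, standalone names):

import numpy as np

def sincos_1d(embed_dim, pos):
    assert embed_dim % 2 == 0
    # omega_d = 1 / 10000^(2d / D), one frequency per sin/cos channel pair
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000 ** omega                         # (D/2,)
    out = np.einsum("m,d->md", pos.reshape(-1), omega)   # (M, D/2) outer product
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)

print(sincos_1d(8, np.arange(4)).shape)  # (4, 8)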
diff --git a/lightllm/server/x2i_server/naive/modeling_neo_chat.py b/lightllm/server/x2i_server/naive/modeling_neo_chat.py
index 326e2a574d..fcd0cdf5f6 100644
--- a/lightllm/server/x2i_server/naive/modeling_neo_chat.py
+++ b/lightllm/server/x2i_server/naive/modeling_neo_chat.py
@@ -18,19 +18,29 @@
from .configuration_neo_chat import NEOChatConfig
from .modeling_neo_vit import NEOVisionModel
from .modeling_qwen3 import Qwen3ForCausalLM, create_block_causal_mask
-from .modeling_fm_modules import PositionEmbedding, TimestepEmbedder, FlowMatchingHead, RMSNorm, NerfEmbedder, SimpleMLPAdaLN, ConvDecoder
+from .modeling_fm_modules import (
+ PositionEmbedding,
+ TimestepEmbedder,
+ FlowMatchingHead,
+ RMSNorm,
+ NerfEmbedder,
+ SimpleMLPAdaLN,
+ ConvDecoder,
+)
logger = logging.get_logger(__name__)
-def version_cmp(v1, v2, op='eq'):
+def version_cmp(v1, v2, op="eq"):
import operator
from packaging import version
+
op_func = getattr(operator, op)
return op_func(version.parse(v1), version.parse(v2))
+
def prepare_flash_kv_cache(
past_key_values,
current_len: int,
@@ -82,6 +92,7 @@ def prepare_flash_kv_cache(
layer.flash_k_cache = k_cache
layer.flash_v_cache = v_cache
+
def clear_flash_kv_cache(past_key_values):
if past_key_values is None:
return
@@ -132,9 +143,10 @@ def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
# Generate intra-image patch index (row-major order)
patch_id_within_image = torch.arange(N_total, device=device)
- patch_id_within_image = patch_id_within_image - torch.cumsum(
- torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0
- )[patch_to_sample]
+ patch_id_within_image = (
+ patch_id_within_image
+ - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample]
+ )
# Get H/W for each patch according to its image
W_per_patch = W[patch_to_sample]
@@ -146,8 +158,8 @@ def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
class NEOChatModel(PreTrainedModel):
config_class = NEOChatConfig
- main_input_name = 'pixel_values'
- base_model_prefix = 'language_model'
+ main_input_name = "pixel_values"
+ base_model_prefix = "language_model"
_supports_flash_attn_2 = True
supports_gradient_checkpointing = True
_no_split_modules = [
@@ -156,17 +168,17 @@ class NEOChatModel(PreTrainedModel):
]
# support transformers 4.51.+
- _tp_plan = ''
+ _tp_plan = ""
def __init__(self, config: NEOChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
super().__init__(config)
- assert version_cmp(transformers.__version__, '4.37.0', 'ge')
+ assert version_cmp(transformers.__version__, "4.37.0", "ge")
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
self.template = config.template
self.downsample_ratio = config.downsample_ratio
- config.llm_config._attn_implementation = 'eager'
+ config.llm_config._attn_implementation = "eager"
if vision_model is not None:
self.vision_model = vision_model
@@ -179,32 +191,33 @@ def __init__(self, config: NEOChatConfig, vision_model=None, language_model=None
self.language_model = Qwen3ForCausalLM(config.llm_config)
merge_size = int(1 / self.downsample_ratio)
- output_dim = 3*(patch_size*merge_size)**2
+ output_dim = 3 * (patch_size * merge_size) ** 2
llm_hidden_size = self.config.llm_config.hidden_size
self.use_deep_fm_head = self.config.fm_head_layers > 2
self.use_pixel_head = self.config.use_pixel_head
if self.use_deep_fm_head:
- fm_head = FlowMatchingHead(llm_hidden_size, output_dim, dim=self.config.fm_head_dim, layers=self.config.fm_head_layers, mlp_ratio=self.config.fm_head_mlp_ratio)
+ fm_head = FlowMatchingHead(
+ llm_hidden_size,
+ output_dim,
+ dim=self.config.fm_head_dim,
+ layers=self.config.fm_head_layers,
+ mlp_ratio=self.config.fm_head_mlp_ratio,
+ )
else:
fm_head = nn.Sequential(
- nn.Linear(llm_hidden_size, 4096, bias=True),
- nn.GELU(),
- nn.Linear(4096, output_dim, bias=True),
- )
+ nn.Linear(llm_hidden_size, 4096, bias=True),
+ nn.GELU(),
+ nn.Linear(4096, output_dim, bias=True),
+ )
timestep_embedder = TimestepEmbedder(llm_hidden_size)
self.fm_modules = nn.ModuleDict(
- {
- "vision_model_mot_gen": vision_model_mot_gen,
- "timestep_embedder": timestep_embedder,
- "fm_head": fm_head
- }
- )
+ {"vision_model_mot_gen": vision_model_mot_gen, "timestep_embedder": timestep_embedder, "fm_head": fm_head}
+ )
if self.use_pixel_head:
self.fm_modules["fm_head"] = ConvDecoder(llm_hidden_size)
-
self.concat_time_token_num = config.concat_time_token_num
self.time_token_id = 151682
self.noise_scale = config.noise_scale
@@ -222,27 +235,22 @@ def __init__(self, config: NEOChatConfig, vision_model=None, language_model=None
if self.add_noise_scale_embedding:
noise_scale_embedder = TimestepEmbedder(llm_hidden_size)
- self.fm_modules['noise_scale_embedder'] = noise_scale_embedder
-
-
+ self.fm_modules["noise_scale_embedder"] = noise_scale_embedder
self.img_context_token_id = None
self.img_start_token_id = 151670
# self.conv_template = get_conv_template(self.template)
# self.system_message = self.conv_template.system_message
-
def extract_feature(self, pixel_values, gen_model=False, grid_hw=None):
if gen_model:
- return self.fm_modules['vision_model_mot_gen'](pixel_values=pixel_values,
- output_hidden_states=False,
- return_dict=True,
- grid_hw=grid_hw).last_hidden_state
+ return self.fm_modules["vision_model_mot_gen"](
+ pixel_values=pixel_values, output_hidden_states=False, return_dict=True, grid_hw=grid_hw
+ ).last_hidden_state
else:
- return self.vision_model(pixel_values=pixel_values,
- output_hidden_states=False,
- return_dict=True,
- grid_hw=grid_hw).last_hidden_state
+ return self.vision_model(
+ pixel_values=pixel_values, output_hidden_states=False, return_dict=True, grid_hw=grid_hw
+ ).last_hidden_state
def patchify(self, images, patch_size, channel_first=False):
"""
@@ -253,11 +261,11 @@ def patchify(self, images, patch_size, channel_first=False):
x = images.reshape(shape=(images.shape[0], 3, h, patch_size, w, patch_size))
if channel_first:
- x = torch.einsum('nchpwq->nhwcpq', x)
+ x = torch.einsum("nchpwq->nhwcpq", x)
else:
- x = torch.einsum('nchpwq->nhwpqc', x)
+ x = torch.einsum("nchpwq->nhwpqc", x)
- x = x.reshape(shape=(images.shape[0], h * w, patch_size**2 * 3))
+ x = x.reshape(shape=(images.shape[0], h * w, patch_size ** 2 * 3))
return x
def unpatchify(sle, x, patch_size, h=None, w=None):
@@ -266,12 +274,12 @@ def unpatchify(sle, x, patch_size, h=None, w=None):
images: (N, 3, H, W)
"""
if h is None or w is None:
- h = w = int(x.shape[1]**.5)
+ h = w = int(x.shape[1] ** 0.5)
else:
h = h // patch_size
w = w // patch_size
x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, 3))
- x = torch.einsum('nhwpqc->nchpwq', x)
+ x = torch.einsum("nhwpqc->nchpwq", x)
images = x.reshape(shape=(x.shape[0], 3, h * patch_size, w * patch_size))
return images
@@ -316,15 +324,25 @@ def _build_t2i_image_indexes(self, token_h, token_w, text_len, device):
w_image = idx % token_w
return torch.stack([t_image, h_image, w_image], dim=0)
-
-
- def _t2i_predict_v(self, input_embeds, indexes_image, attn_mask, past_key_values, t, z,
- image_token_num, timestep_embeddings=None, image_size=None):
+ def _t2i_predict_v(
+ self,
+ input_embeds,
+ indexes_image,
+ attn_mask,
+ past_key_values,
+ t,
+ z,
+ image_token_num,
+ timestep_embeddings=None,
+ image_size=None,
+ ):
B, L = z.shape[0], z.shape[1]
outputs = self.language_model.model(
inputs_embeds=input_embeds,
- image_gen_indicators=torch.ones((input_embeds.shape[0], input_embeds.shape[1]), dtype=torch.bool, device=input_embeds.device),
+ image_gen_indicators=torch.ones(
+ (input_embeds.shape[0], input_embeds.shape[1]), dtype=torch.bool, device=input_embeds.device
+ ),
indexes=indexes_image,
attention_mask=attn_mask,
past_key_values=past_key_values,
@@ -341,44 +359,47 @@ def _t2i_predict_v(self, input_embeds, indexes_image, attn_mask, past_key_values
img_2d = torch.einsum("b h w c -> b c h w", img_reshaped)
img_2d = img_2d.contiguous().view(B, -1, token_h, token_w)
- smoothed_img_2d = self.fm_modules['fm_head'](img_2d)
+ smoothed_img_2d = self.fm_modules["fm_head"](img_2d)
- smoothed_reshaped = smoothed_img_2d.view(B, 3, token_h, self.patch_size * merge_size, token_w, self.patch_size * merge_size)
+ smoothed_reshaped = smoothed_img_2d.view(
+ B, 3, token_h, self.patch_size * merge_size, token_w, self.patch_size * merge_size
+ )
smoothed_reshaped = torch.einsum("b c h p w q -> b h w p q c", smoothed_reshaped)
- out_1d = smoothed_reshaped.contiguous().view(B, L, self.patch_size * merge_size * self.patch_size * merge_size * 3)
+ out_1d = smoothed_reshaped.contiguous().view(
+ B, L, self.patch_size * merge_size * self.patch_size * merge_size * 3
+ )
x_pred = out_1d
else:
if self.use_deep_fm_head:
x_pred = self.fm_modules["fm_head"](
- outputs.last_hidden_state[:, -image_token_num:].view(B*L, -1), t.repeat(B*L)
+ outputs.last_hidden_state[:, -image_token_num:].view(B * L, -1), t.repeat(B * L)
).view(B, L, -1)
else:
x_pred = self.fm_modules["fm_head"](
outputs.last_hidden_state[:, -image_token_num:].view(B, L, -1)
).view(B, L, -1)
-
v_pred = (x_pred - z) / (1 - t).clamp_min(self.config.t_eps)
return v_pred
-
@torch.no_grad()
- def it2i_generate(self,
- past_key_values_condition,
- past_key_values_text_uncondition,
- past_key_values_img_uncondition,
- text_lens,
- cfg_scale=1,
- img_cfg_scale=1,
- cfg_norm='none',
- enable_timestep_shift=True,
- timestep_shift=3,
- image_size=(256, 256),
- num_steps=30,
- cfg_interval=(0.1, 1.0),
- batch_size=1,
- t_eps=0.02,
- ):
+ def it2i_generate(
+ self,
+ past_key_values_condition,
+ past_key_values_text_uncondition,
+ past_key_values_img_uncondition,
+ text_lens,
+ cfg_scale=1,
+ img_cfg_scale=1,
+ cfg_norm="none",
+ enable_timestep_shift=True,
+ timestep_shift=3,
+ image_size=(256, 256),
+ num_steps=30,
+ cfg_interval=(0.1, 1.0),
+ batch_size=1,
+ t_eps=0.02,
+ ):
self.config.t_eps = t_eps
device, dtype = self.get_cache_device_dtype(past_key_values_condition)
@@ -394,12 +415,24 @@ def it2i_generate(self,
indexes_image_img_uncondition = self._build_t2i_image_indexes(token_h, token_w, S3, device=device)
for layer_idx in range(len(past_key_values_condition.layers)):
- past_key_values_condition.layers[layer_idx].keys = past_key_values_condition.layers[layer_idx].keys.expand(batch_size, *past_key_values_condition.layers[layer_idx].keys.shape[1:])
- past_key_values_condition.layers[layer_idx].values = past_key_values_condition.layers[layer_idx].values.expand(batch_size, *past_key_values_condition.layers[layer_idx].values.shape[1:])
- past_key_values_text_uncondition.layers[layer_idx].keys = past_key_values_text_uncondition.layers[layer_idx].keys.expand(batch_size, *past_key_values_text_uncondition.layers[layer_idx].keys.shape[1:])
- past_key_values_text_uncondition.layers[layer_idx].values = past_key_values_text_uncondition.layers[layer_idx].values.expand(batch_size, *past_key_values_text_uncondition.layers[layer_idx].values.shape[1:])
- past_key_values_img_uncondition.layers[layer_idx].keys = past_key_values_img_uncondition.layers[layer_idx].keys.expand(batch_size, *past_key_values_img_uncondition.layers[layer_idx].keys.shape[1:])
- past_key_values_img_uncondition.layers[layer_idx].values = past_key_values_img_uncondition.layers[layer_idx].values.expand(batch_size, *past_key_values_img_uncondition.layers[layer_idx].values.shape[1:])
+ past_key_values_condition.layers[layer_idx].keys = past_key_values_condition.layers[layer_idx].keys.expand(
+ batch_size, *past_key_values_condition.layers[layer_idx].keys.shape[1:]
+ )
+ past_key_values_condition.layers[layer_idx].values = past_key_values_condition.layers[
+ layer_idx
+ ].values.expand(batch_size, *past_key_values_condition.layers[layer_idx].values.shape[1:])
+ past_key_values_text_uncondition.layers[layer_idx].keys = past_key_values_text_uncondition.layers[
+ layer_idx
+ ].keys.expand(batch_size, *past_key_values_text_uncondition.layers[layer_idx].keys.shape[1:])
+ past_key_values_text_uncondition.layers[layer_idx].values = past_key_values_text_uncondition.layers[
+ layer_idx
+ ].values.expand(batch_size, *past_key_values_text_uncondition.layers[layer_idx].values.shape[1:])
+ past_key_values_img_uncondition.layers[layer_idx].keys = past_key_values_img_uncondition.layers[
+ layer_idx
+ ].keys.expand(batch_size, *past_key_values_img_uncondition.layers[layer_idx].keys.shape[1:])
+ past_key_values_img_uncondition.layers[layer_idx].values = past_key_values_img_uncondition.layers[
+ layer_idx
+ ].values.expand(batch_size, *past_key_values_img_uncondition.layers[layer_idx].values.shape[1:])
prepare_flash_kv_cache(
past_key_values_condition,
@@ -417,31 +450,32 @@ def it2i_generate(self,
batch_size=batch_size,
)
-
# init noise image tokens
grid_h = image_size[1] // self.patch_size
grid_w = image_size[0] // self.patch_size
grid_hw = torch.tensor([[grid_h, grid_w]] * batch_size, device=device)
noise_scale = self.noise_scale
- if self.noise_scale_mode in ("resolution", "dynamic", 'dynamic_sqrt'):
- noise_scale = math.sqrt((grid_h*grid_w)/(merge_size**2) / self.noise_scale_base_image_seq_len)
+ if self.noise_scale_mode in ("resolution", "dynamic", "dynamic_sqrt"):
+ noise_scale = math.sqrt((grid_h * grid_w) / (merge_size ** 2) / self.noise_scale_base_image_seq_len)
base = float(self.noise_scale_base_image_seq_len)
- scale = math.sqrt((grid_h*grid_w)/(merge_size**2)/base)
+ scale = math.sqrt((grid_h * grid_w) / (merge_size ** 2) / base)
noise_scale = scale * float(self.noise_scale)
- if self.noise_scale_mode == 'dynamic_sqrt':
+ if self.noise_scale_mode == "dynamic_sqrt":
noise_scale = math.sqrt(noise_scale)
noise_scale = min(noise_scale, self.noise_scale_max_value)
- image_prediction = noise_scale * torch.randn((batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype)
+ image_prediction = noise_scale * torch.randn(
+ (batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype
+ )
attention_mask_condition = {"full_attention": None}
attention_mask_text_uncondition = {"full_attention": None}
attention_mask_img_uncondition = {"full_attention": None}
- timesteps = torch.linspace(0.0, 1.0, num_steps+1, device=device)
+ timesteps = torch.linspace(0.0, 1.0, num_steps + 1, device=device)
if enable_timestep_shift:
- timesteps = self._apply_time_schedule(timesteps, token_h*token_w, timestep_shift)
+ timesteps = self._apply_time_schedule(timesteps, token_h * token_w, timestep_shift)
for step_i in range(num_steps):
t = timesteps[step_i]
@@ -450,16 +484,32 @@ def it2i_generate(self,
z = self.patchify(image_prediction, self.patch_size * merge_size)
image_input = self.patchify(image_prediction, self.patch_size, channel_first=True)
- image_embeds = self.extract_feature(image_input.view(batch_size * grid_h*grid_w, -1), gen_model=True, grid_hw=grid_hw).view(batch_size, token_h*token_w, -1)
- t_expanded = t.expand(batch_size*token_h*token_w)
- timestep_embeddings = self.fm_modules['timestep_embedder'](t_expanded).view(batch_size, token_h*token_w, -1)
+ image_embeds = self.extract_feature(
+ image_input.view(batch_size * grid_h * grid_w, -1), gen_model=True, grid_hw=grid_hw
+ ).view(batch_size, token_h * token_w, -1)
+ t_expanded = t.expand(batch_size * token_h * token_w)
+ timestep_embeddings = self.fm_modules["timestep_embedder"](t_expanded).view(
+ batch_size, token_h * token_w, -1
+ )
if self.add_noise_scale_embedding:
- noise_scale_tensor = torch.full_like(t_expanded, noise_scale/self.noise_scale_max_value)
- noise_embeddings = self.fm_modules['noise_scale_embedder'](noise_scale_tensor).view(batch_size, token_h*token_w, -1)
+ noise_scale_tensor = torch.full_like(t_expanded, noise_scale / self.noise_scale_max_value)
+ noise_embeddings = self.fm_modules["noise_scale_embedder"](noise_scale_tensor).view(
+ batch_size, token_h * token_w, -1
+ )
timestep_embeddings += noise_embeddings
image_embeds = image_embeds + timestep_embeddings
- v_pred_condition = self._t2i_predict_v(image_embeds, indexes_image_condition, attention_mask_condition, past_key_values_condition, t, z, image_token_num=token_h*token_w, timestep_embeddings=timestep_embeddings,image_size=image_size)
+ v_pred_condition = self._t2i_predict_v(
+ image_embeds,
+ indexes_image_condition,
+ attention_mask_condition,
+ past_key_values_condition,
+ t,
+ z,
+ image_token_num=token_h * token_w,
+ timestep_embeddings=timestep_embeddings,
+ image_size=image_size,
+ )
if not use_cfg:
v_pred = v_pred_condition
elif cfg_scale == 1 and img_cfg_scale == 1:
@@ -489,7 +539,7 @@ def it2i_generate(self,
timestep_embeddings=timestep_embeddings,
image_size=image_size,
)
- v_pred = out_uncond + cfg_scale *(v_pred_condition - out_uncond)
+ v_pred = out_uncond + cfg_scale * (v_pred_condition - out_uncond)
else:
out_img_cond = self._t2i_predict_v(
image_embeds,
@@ -520,18 +570,17 @@ def it2i_generate(self,
)
if cfg_scale > 1 or img_cfg_scale > 1:
- if cfg_norm == 'global':
+ if cfg_norm == "global":
norm_v_condition = torch.norm(v_pred_condition, dim=(1, 2), keepdim=True)
norm_v_cfg = torch.norm(v_pred, dim=(1, 2), keepdim=True)
scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
v_pred = v_pred * scale
- elif cfg_norm == 'channel':
+ elif cfg_norm == "channel":
norm_v_condition = torch.norm(v_pred_condition, dim=-1, keepdim=True)
norm_v_cfg = torch.norm(v_pred, dim=-1, keepdim=True)
scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
v_pred = v_pred * scale
-
z = z + (t_next - t) * v_pred
image_prediction = self.unpatchify(z, self.patch_size * merge_size, image_size[1], image_size[0])
@@ -542,7 +591,6 @@ def it2i_generate(self,
return image_prediction
-
def get_cache_device_dtype(self, cache):
"""
Returns (device, dtype) of a DynamicCache.
@@ -553,20 +601,22 @@ def get_cache_device_dtype(self, cache):
raise ValueError("Cache is empty")
@torch.no_grad()
- def t2i_generate(self,
- past_key_values_condition,
- past_key_values_uncondition,
- text_lens,
- cfg_scale=1,
- timestep_shift=3,
- enable_timestep_shift=True,
- cfg_norm='none',
- image_size=(256, 256),
- num_steps=30,
- cfg_interval=(0.1, 1.0),
- batch_size=1,
- t_eps=0.02):
- assert cfg_norm in ['cfg_zero_star', 'global', 'none'], f"cfg_norm={cfg_norm}"
+ def t2i_generate(
+ self,
+ past_key_values_condition,
+ past_key_values_uncondition,
+ text_lens,
+ cfg_scale=1,
+ timestep_shift=3,
+ enable_timestep_shift=True,
+ cfg_norm="none",
+ image_size=(256, 256),
+ num_steps=30,
+ cfg_interval=(0.1, 1.0),
+ batch_size=1,
+ t_eps=0.02,
+ ):
+ assert cfg_norm in ["cfg_zero_star", "global", "none"], f"cfg_norm={cfg_norm}"
merge_size = int(1 / self.downsample_ratio)
self.config.t_eps = t_eps
@@ -580,10 +630,18 @@ def t2i_generate(self,
indexes_image_uncondition = self._build_t2i_image_indexes(token_h, token_w, S2, device=device)
for layer_idx in range(len(past_key_values_condition.layers)):
- past_key_values_condition.layers[layer_idx].keys = past_key_values_condition.layers[layer_idx].keys.expand(batch_size, *past_key_values_condition.layers[layer_idx].keys.shape[1:])
- past_key_values_condition.layers[layer_idx].values = past_key_values_condition.layers[layer_idx].values.expand(batch_size, *past_key_values_condition.layers[layer_idx].values.shape[1:])
- past_key_values_uncondition.layers[layer_idx].keys = past_key_values_uncondition.layers[layer_idx].keys.expand(batch_size, *past_key_values_uncondition.layers[layer_idx].keys.shape[1:])
- past_key_values_uncondition.layers[layer_idx].values = past_key_values_uncondition.layers[layer_idx].values.expand(batch_size, *past_key_values_uncondition.layers[layer_idx].values.shape[1:])
+ past_key_values_condition.layers[layer_idx].keys = past_key_values_condition.layers[layer_idx].keys.expand(
+ batch_size, *past_key_values_condition.layers[layer_idx].keys.shape[1:]
+ )
+ past_key_values_condition.layers[layer_idx].values = past_key_values_condition.layers[
+ layer_idx
+ ].values.expand(batch_size, *past_key_values_condition.layers[layer_idx].values.shape[1:])
+ past_key_values_uncondition.layers[layer_idx].keys = past_key_values_uncondition.layers[
+ layer_idx
+ ].keys.expand(batch_size, *past_key_values_uncondition.layers[layer_idx].keys.shape[1:])
+ past_key_values_uncondition.layers[layer_idx].values = past_key_values_uncondition.layers[
+ layer_idx
+ ].values.expand(batch_size, *past_key_values_uncondition.layers[layer_idx].values.shape[1:])
# prepare flash cache once
prepare_flash_kv_cache(
@@ -600,27 +658,29 @@ def t2i_generate(self,
# init noise image tokens
grid_h = image_size[1] // self.patch_size
grid_w = image_size[0] // self.patch_size
- grid_hw = torch.tensor([[grid_h, grid_w]]*batch_size, device=device)
+ grid_hw = torch.tensor([[grid_h, grid_w]] * batch_size, device=device)
noise_scale = self.noise_scale
- if self.noise_scale_mode in ("resolution", "dynamic", 'dynamic_sqrt'):
- noise_scale = math.sqrt((grid_h*grid_w)/(merge_size**2) / self.noise_scale_base_image_seq_len)
+ if self.noise_scale_mode in ("resolution", "dynamic", "dynamic_sqrt"):
+ noise_scale = math.sqrt((grid_h * grid_w) / (merge_size ** 2) / self.noise_scale_base_image_seq_len)
base = float(self.noise_scale_base_image_seq_len)
- scale = math.sqrt((grid_h*grid_w)/(merge_size**2)/base)
+ scale = math.sqrt((grid_h * grid_w) / (merge_size ** 2) / base)
noise_scale = scale * float(self.noise_scale)
- if self.noise_scale_mode == 'dynamic_sqrt':
+ if self.noise_scale_mode == "dynamic_sqrt":
noise_scale = math.sqrt(noise_scale)
noise_scale = min(noise_scale, self.noise_scale_max_value)
- image_prediction = noise_scale * torch.randn((batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype)
+ image_prediction = noise_scale * torch.randn(
+ (batch_size, 3, image_size[1], image_size[0]), device=device, dtype=dtype
+ )
attention_mask_condition = {"full_attention": None}
attention_mask_uncondition = {"full_attention": None}
- timesteps = torch.linspace(0.0, 1.0, num_steps+1, device=device)
+ timesteps = torch.linspace(0.0, 1.0, num_steps + 1, device=device)
if enable_timestep_shift:
- timesteps = self._apply_time_schedule(timesteps, token_h*token_w, timestep_shift)
+ timesteps = self._apply_time_schedule(timesteps, token_h * token_w, timestep_shift)
for step_i in range(num_steps):
t = timesteps[step_i]
@@ -628,43 +688,67 @@ def t2i_generate(self,
z = self.patchify(image_prediction, self.patch_size * merge_size)
image_input = self.patchify(image_prediction, self.patch_size, channel_first=True)
- image_embeds = self.extract_feature(image_input.view(batch_size * grid_h*grid_w, -1), gen_model=True, grid_hw=grid_hw).view(batch_size, token_h*token_w, -1)
- t_expanded = t.expand(batch_size*token_h*token_w)
- timestep_embeddings = self.fm_modules['timestep_embedder'](t_expanded).view(batch_size, token_h*token_w, -1)
+ image_embeds = self.extract_feature(
+ image_input.view(batch_size * grid_h * grid_w, -1), gen_model=True, grid_hw=grid_hw
+ ).view(batch_size, token_h * token_w, -1)
+ t_expanded = t.expand(batch_size * token_h * token_w)
+ timestep_embeddings = self.fm_modules["timestep_embedder"](t_expanded).view(
+ batch_size, token_h * token_w, -1
+ )
if self.add_noise_scale_embedding:
noise_scale_tensor = torch.full_like(t_expanded, noise_scale / self.noise_scale_max_value)
- noise_embeddings = self.fm_modules['noise_scale_embedder'](noise_scale_tensor).view(batch_size, token_h*token_w, -1)
+ noise_embeddings = self.fm_modules["noise_scale_embedder"](noise_scale_tensor).view(
+ batch_size, token_h * token_w, -1
+ )
timestep_embeddings += noise_embeddings
image_embeds = image_embeds + timestep_embeddings
-
- v_pred_condition = self._t2i_predict_v(image_embeds, indexes_image_condition, attention_mask_condition, past_key_values_condition, t, z, image_token_num=token_h*token_w,
- timestep_embeddings=timestep_embeddings, image_size=image_size)
-
+ v_pred_condition = self._t2i_predict_v(
+ image_embeds,
+ indexes_image_condition,
+ attention_mask_condition,
+ past_key_values_condition,
+ t,
+ z,
+ image_token_num=token_h * token_w,
+ timestep_embeddings=timestep_embeddings,
+ image_size=image_size,
+ )
if t >= cfg_interval[0] and t <= cfg_interval[1] and cfg_scale > 1:
- v_pred_uncondition = self._t2i_predict_v(image_embeds, indexes_image_uncondition, attention_mask_uncondition, past_key_values_uncondition, t, z, image_token_num=token_h*token_w,
- timestep_embeddings=timestep_embeddings, image_size=image_size)
- if cfg_norm == 'cfg_zero_star':
+ v_pred_uncondition = self._t2i_predict_v(
+ image_embeds,
+ indexes_image_uncondition,
+ attention_mask_uncondition,
+ past_key_values_uncondition,
+ t,
+ z,
+ image_token_num=token_h * token_w,
+ timestep_embeddings=timestep_embeddings,
+ image_size=image_size,
+ )
+ if cfg_norm == "cfg_zero_star":
positive_flat = v_pred_condition.view(batch_size, -1)
negative_flat = v_pred_uncondition.view(batch_size, -1)
- alpha = optimized_scale(positive_flat,negative_flat)
+ alpha = optimized_scale(positive_flat, negative_flat)
alpha = alpha.view(batch_size, *([1] * (len(v_pred_condition.shape) - 1)))
alpha = alpha.to(positive_flat.dtype)
- if (step_i <= 0):
- v_pred = v_pred_condition*0.
+ if step_i <= 0:
+ v_pred = v_pred_condition * 0.0
else:
- v_pred = v_pred_uncondition * alpha + cfg_scale * (v_pred_condition - v_pred_uncondition * alpha)
+ v_pred = v_pred_uncondition * alpha + cfg_scale * (
+ v_pred_condition - v_pred_uncondition * alpha
+ )
else:
v_pred = v_pred_uncondition + cfg_scale * (v_pred_condition - v_pred_uncondition)
- if cfg_norm == 'global':
- norm_v_condition = torch.norm(v_pred_condition, dim=(1,2), keepdim=True)
- norm_v_cfg = torch.norm(v_pred, dim=(1,2), keepdim=True)
+ if cfg_norm == "global":
+ norm_v_condition = torch.norm(v_pred_condition, dim=(1, 2), keepdim=True)
+ norm_v_cfg = torch.norm(v_pred, dim=(1, 2), keepdim=True)
scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
v_pred = v_pred * scale
- elif cfg_norm == 'channel':
+ elif cfg_norm == "channel":
norm_v_condition = torch.norm(v_pred_condition, dim=-1, keepdim=True)
norm_v_cfg = torch.norm(v_pred, dim=-1, keepdim=True)
scale = (norm_v_condition / (norm_v_cfg + 1e-8)).clamp(min=0, max=1.0)
@@ -682,7 +766,6 @@ def t2i_generate(self,
return image_prediction
-
@property
def lm_head(self):
return self.language_model.get_output_embeddings()
@@ -700,25 +783,29 @@ def set_output_embeddings(self, value):
return self.language_model.set_output_embeddings(value)
def get_thw_indexes(self, input_ids, grid_hw=None):
- img_start_shift = torch.cat([torch.zeros(1, dtype=torch.long).to(input_ids.device),
- (input_ids == self.img_start_token_id).long()], dim=0)[:-1]
+ img_start_shift = torch.cat(
+ [torch.zeros(1, dtype=torch.long).to(input_ids.device), (input_ids == self.img_start_token_id).long()],
+ dim=0,
+ )[:-1]
not_img_token = (input_ids != self.img_context_token_id).long()
- t_indexes = ((img_start_shift + not_img_token).cumsum(0) - 1)
+ t_indexes = (img_start_shift + not_img_token).cumsum(0) - 1
h_indexes = torch.zeros_like(t_indexes).to(t_indexes.device)
w_indexes = torch.zeros_like(t_indexes).to(t_indexes.device)
if grid_hw is not None:
- selected = (input_ids == self.img_context_token_id)
+ selected = input_ids == self.img_context_token_id
if selected.long().sum() > 0:
abs_pos_w, abs_pos_h = build_abs_positions_from_grid_hw(
- grid_hw // int(1 / self.downsample_ratio), device=t_indexes.device)
+ grid_hw // int(1 / self.downsample_ratio), device=t_indexes.device
+ )
h_indexes[selected] = abs_pos_h.to(t_indexes.device, t_indexes.dtype)
w_indexes[selected] = abs_pos_w.to(t_indexes.device, t_indexes.dtype)
return torch.stack([t_indexes, h_indexes, w_indexes], dim=0)
NORM_MEAN = [0.5, 0.5, 0.5]
-NORM_STD = [0.5, 0.5, 0.5]
+NORM_STD = [0.5, 0.5, 0.5]
+
class NEOX2I:
def __init__(self, model_path, device):
@@ -736,7 +823,7 @@ def _denorm(self, x: torch.Tensor, mean=NORM_MEAN, std=NORM_STD):
x: [B,3,H,W] normalized ((img-mean)/std). returns [0,1] clamped.
"""
mean = torch.tensor(mean, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
- std = torch.tensor(std, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
+ std = torch.tensor(std, device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
return (x * std + mean).clamp(0, 1)
def _get_dynamic_cache(self, past_kv):
@@ -748,15 +835,17 @@ def _get_dynamic_cache(self, past_kv):
for layer_idx in range(L):
k = past_kv[layer_idx][0].unsqueeze(0).to(self.device, non_blocking=True)
v = past_kv[layer_idx][1].unsqueeze(0).to(self.device, non_blocking=True)
- past_kv_dc.update(key_states=k, value_states=v, layer_idx=layer_idx,)
+ past_kv_dc.update(
+ key_states=k,
+ value_states=v,
+ layer_idx=layer_idx,
+ )
return past_kv_dc
-
def t2i(self, past_kv, past_kv_txt, param: X2IParams):
past_kv_dc = self._get_dynamic_cache(past_kv)
past_kv_txt_dc = self._get_dynamic_cache(past_kv_txt)
- text_lens = (param.past_kvcache.get_compressed_len(),
- param.past_kvcache_text.get_compressed_len())
+ text_lens = (param.past_kvcache.get_compressed_len(), param.past_kvcache_text.get_compressed_len())
output = self.model.t2i_generate(
past_key_values_condition=past_kv_dc,
past_key_values_uncondition=past_kv_txt_dc,
@@ -766,7 +855,8 @@ def t2i(self, past_kv, past_kv_txt, param: X2IParams):
image_size=(param.width, param.height),
num_steps=param.steps,
batch_size=param.num_images,
- timestep_shift=param.timestep_shift)
+ timestep_shift=param.timestep_shift,
+ )
return self._post_process(output)
@@ -774,19 +864,18 @@ def _post_process(self, output):
images = self._denorm(output)
images = (images.clamp(0, 1) * 255.0).round().to(torch.uint8).cpu()
- base64_images = [
- base64.b64encode(io.encode_jpeg(img).numpy()).decode("utf-8")
- for img in images
- ]
+ base64_images = [base64.b64encode(io.encode_jpeg(img).numpy()).decode("utf-8") for img in images]
return base64_images
def it2i(self, past_kv, past_kv_txt, past_kv_img, param: X2IParams):
past_kv_dc = self._get_dynamic_cache(past_kv)
past_kv_txt_dc = self._get_dynamic_cache(past_kv_txt)
past_kv_img_dc = self._get_dynamic_cache(past_kv_img)
- text_lens = (param.past_kvcache.get_compressed_len(),
- param.past_kvcache_text.get_compressed_len(),
- param.past_kvcache_img.get_compressed_len())
+ text_lens = (
+ param.past_kvcache.get_compressed_len(),
+ param.past_kvcache_text.get_compressed_len(),
+ param.past_kvcache_img.get_compressed_len(),
+ )
output = self.model.it2i_generate(
past_key_values_condition=past_kv_dc,
past_key_values_text_uncondition=past_kv_txt_dc,
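The t2i_generate and it2i_generate loops reformatted above are a flow-matching Euler sampler with classifier-free guidance: the head predicts the clean image x, the velocity is v = (x_pred - z) / (1 - t) clamped by t_eps, conditional and unconditional velocities are mixed by the cfg scale, and z is advanced by (t_next - t) * v from noise at t = 0 toward the image at t = 1. A compact sketch of that control flow, using dummy predictors as stand-ins for the language model plus fm_head path:

import torch

def flow_matching_sample(pred_x_cond, pred_x_uncond, z, num_steps=30, cfg_scale=3.0, t_eps=0.02):
    timesteps = torch.linspace(0.0, 1.0, num_steps + 1)
    for i in range(num_steps):
        t, t_next = timesteps[i], timesteps[i + 1]
        # The head predicts the clean sample x; convert it to a velocity estimate.
        v_cond = (pred_x_cond(z, t) - z) / (1 - t).clamp_min(t_eps)
        v_uncond = (pred_x_uncond(z, t) - z) / (1 - t).clamp_min(t_eps)
        # Classifier-free guidance on the velocity, then one Euler step.
        v = v_uncond + cfg_scale * (v_cond - v_uncond)
        z = z + (t_next - t) * v
    return z

z = torch.randn(1, 64, 48)
out = flow_matching_sample(lambda z, t: torch.zeros_like(z), lambda z, t: torch.zeros_like(z), z)
print(out.shape)  # torch.Size([1, 64, 48])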
diff --git a/lightllm/server/x2i_server/naive/modeling_neo_vit.py b/lightllm/server/x2i_server/naive/modeling_neo_vit.py
index 63ded83e7d..b38a2677e5 100644
--- a/lightllm/server/x2i_server/naive/modeling_neo_vit.py
+++ b/lightllm/server/x2i_server/naive/modeling_neo_vit.py
@@ -9,9 +9,7 @@
from .configuration_neo_vit import NEOVisionConfig
-def precompute_rope_freqs_sincos(
- dim: int, max_position: int, base: float = 10000.0, device=None
-):
+def precompute_rope_freqs_sincos(dim: int, max_position: int, base: float = 10000.0, device=None):
"""预计算 RoPE 的 cos 和 sin 值 (1D)。"""
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
t = torch.arange(max_position, device=device).type_as(inv_freq)
@@ -40,9 +38,10 @@ def build_abs_positions_from_grid_hw(grid_hw: torch.Tensor, device=None):
# Generate intra-image patch index (row-major order)
patch_id_within_image = torch.arange(N_total, device=device)
- patch_id_within_image = patch_id_within_image - torch.cumsum(
- torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0
- )[patch_to_sample]
+ patch_id_within_image = (
+ patch_id_within_image
+ - torch.cumsum(torch.cat([torch.tensor([0], device=device), N[:-1]]), dim=0)[patch_to_sample]
+ )
# Get H/W for each patch according to its image
W_per_patch = W[patch_to_sample]
@@ -63,8 +62,8 @@ def apply_rotary_emb_1d(
# positions: (..., seq_len)
# cos_cached: (max_pos, dim_part / 2)
- cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2)
- sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2)
+ cos = cos_cached[positions] # Shape: (positions.shape, dim_part / 2)
+ sin = sin_cached[positions] # Shape: (positions.shape, dim_part / 2)
x1 = x[..., 0::2]
x2 = x[..., 1::2]
@@ -85,7 +84,7 @@ def apply_2d_rotary_pos_emb(
cos_cached_y: torch.Tensor,
sin_cached_y: torch.Tensor,
abs_positions_x: torch.Tensor,
- abs_positions_y: torch.Tensor
+ abs_positions_y: torch.Tensor,
):
"""应用2D RoPE到输入张量x。"""
dim = x.shape[-1]
@@ -97,13 +96,9 @@ def apply_2d_rotary_pos_emb(
x_part_2 = x[..., dim_half:]
# Apply the rotation associated with abs_positions_x to x_part_1
- rotated_part_1 = apply_rotary_emb_1d(
- x_part_1, cos_cached_x, sin_cached_x, abs_positions_x
- )
+ rotated_part_1 = apply_rotary_emb_1d(x_part_1, cos_cached_x, sin_cached_x, abs_positions_x)
# Apply the rotation associated with abs_positions_y to x_part_2
- rotated_part_2 = apply_rotary_emb_1d(
- x_part_2, cos_cached_y, sin_cached_y, abs_positions_y
- )
+ rotated_part_2 = apply_rotary_emb_1d(x_part_2, cos_cached_y, sin_cached_y, abs_positions_y)
# Concatenate them back together; keep the order consistent with the split above.
return torch.cat((rotated_part_1, rotated_part_2), dim=-1)
@@ -123,10 +118,16 @@ def __init__(self, config: NEOVisionConfig):
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
- in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
)
self.dense_embedding = nn.Conv2d(
- in_channels=self.embed_dim, out_channels=self.llm_embed_dim, kernel_size=self.downsample_factor, stride=self.downsample_factor
+ in_channels=self.embed_dim,
+ out_channels=self.llm_embed_dim,
+ kernel_size=self.downsample_factor,
+ stride=self.downsample_factor,
)
self.gelu = nn.GELU()
@@ -149,11 +150,13 @@ def _apply_2d_rotary_pos_emb(self, patch_embeds, grid_hw):
"""
abs_pos_x, abs_pos_y = build_abs_positions_from_grid_hw(grid_hw, device=patch_embeds.device)
embeddings = apply_2d_rotary_pos_emb(
- patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32
- self.cos_cached_x, self.sin_cached_x,
- self.cos_cached_y, self.sin_cached_y,
+ patch_embeds.to(torch.float32), # RoPE calculations are often more stable in float32
+ self.cos_cached_x,
+ self.sin_cached_x,
+ self.cos_cached_y,
+ self.sin_cached_y,
abs_pos_x,
- abs_pos_y
+ abs_pos_y,
).to(self.patch_embedding.weight.dtype)
return embeddings
@@ -164,14 +167,14 @@ def forward(self, pixel_values: torch.FloatTensor, grid_hw=None) -> torch.Tensor
3,
self.patch_size,
self.patch_size,
- ) # [28072, 768] -> [28072, 3, 16, 16]
+ ) # [28072, 768] -> [28072, 3, 16, 16]
patch_embeds = self.gelu(self.patch_embedding(pixel_values)).view(-1, self.embed_dim)
self.cos_cached_x = self.cos_cached_x.to(patch_embeds.device)
self.sin_cached_x = self.sin_cached_x.to(patch_embeds.device)
self.cos_cached_y = self.cos_cached_y.to(patch_embeds.device)
self.sin_cached_y = self.sin_cached_y.to(patch_embeds.device)
- patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) # [28072, 1024]
- assert (grid_hw[:,0] * grid_hw[:,1]).sum() == patch_embeds.shape[0]
+ patch_embeds = self._apply_2d_rotary_pos_emb(patch_embeds, grid_hw) # [28072, 1024]
+ assert (grid_hw[:, 0] * grid_hw[:, 1]).sum() == patch_embeds.shape[0]
patches_list = []
cur_position = 0
@@ -186,18 +189,18 @@ def forward(self, pixel_values: torch.FloatTensor, grid_hw=None) -> torch.Tensor
embeddings = torch.cat(patches_list, dim=0) # (N_total // downsample_factor**2, C)
assert cur_position == patch_embeds.shape[0]
- assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor**2)
+ assert embeddings.shape[0] == int(patch_embeds.shape[0] / self.downsample_factor ** 2)
return embeddings
class NEOVisionModel(PreTrainedModel):
- main_input_name = 'pixel_values'
+ main_input_name = "pixel_values"
_supports_flash_attn_2 = True
supports_gradient_checkpointing = True
config_class = NEOVisionConfig
# support transformers 4.51.+
- _tp_plan = ''
+ _tp_plan = ""
def __init__(self, config: NEOVisionConfig):
super().__init__(config)
@@ -206,12 +209,12 @@ def __init__(self, config: NEOVisionConfig):
self.embeddings = NEOVisionEmbeddings(config)
def forward(
- self,
- pixel_values: Optional[torch.FloatTensor] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- pixel_embeds: Optional[torch.FloatTensor] = None,
- grid_hw: Optional[torch.Tensor] = None
+ self,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_embeds: Optional[torch.FloatTensor] = None,
+ grid_hw: Optional[torch.Tensor] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -219,7 +222,7 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None and pixel_embeds is None:
- raise ValueError('You have to specify pixel_values or pixel_embeds')
+ raise ValueError("You have to specify pixel_values or pixel_embeds")
if pixel_embeds is not None:
hidden_states = pixel_embeds
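The vision embedding hunks above apply 2D RoPE by splitting the channel dimension in half, rotating the first half with the x positions and the second half with the y positions, each through a 1D rotary step that looks up cached cos/sin values per position. The rotation itself is not shown in the hunk; the sketch below uses the standard even/odd RoPE pairing, which is an assumption:

import torch

def apply_rope_1d(x, cos_cached, sin_cached, positions):
    # Look up the cached angles for each position, then rotate (even, odd) channel
    # pairs: (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos).
    cos, sin = cos_cached[positions], sin_cached[positions]  # (..., dim/2)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

dim, max_pos = 8, 32
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(max_pos).float(), inv_freq)  # (max_pos, dim/2)
y = apply_rope_1d(torch.randn(4, dim), angles.cos(), angles.sin(), torch.arange(4))
print(y.shape)  # torch.Size([4, 8])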
diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py
index 4abfde5a83..ec2d9d1760 100644
--- a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py
+++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py
@@ -148,9 +148,7 @@ def _build_inputs(
n_spans = int(torch.randint(1, num_image_spans_max + 1, (1,), generator=g).item())
start_pack = int(b_start_loc[i].item())
for _ in range(n_spans):
- span_len = int(
- torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item()
- )
+ span_len = int(torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item())
span_len = min(span_len, M)
s_rel = int(torch.randint(0, M - span_len + 1, (1,), generator=g).item())
b_image_token_tag[start_pack + s_rel : start_pack + s_rel + span_len] = True
diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py
index 9a635f0445..41f6b476ba 100644
--- a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py
+++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_fa3_neo.py
@@ -172,9 +172,7 @@ def _build_inputs(
n_spans = int(torch.randint(1, num_image_spans_max + 1, (1,), generator=g).item())
start_pack = int(b_q_start_loc[i].item())
for _ in range(n_spans):
- span_len = int(
- torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item()
- )
+ span_len = int(torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item())
span_len = min(span_len, M)
s_rel = int(torch.randint(0, M - span_len + 1, (1,), generator=g).item())
b_image_token_tag[start_pack + s_rel : start_pack + s_rel + span_len] = True
@@ -214,9 +212,9 @@ def _fa3_prefill_with_image_tag(inputs: dict) -> torch.Tensor:
device = q.device
# Build page_table[b, p] = req_to_token_indexs[b_req_idx[b], p].
- page_table = inputs["req_to_token_indexs"][
- inputs["b_req_idx"].long(), : inputs["max_seq_len_in_batch"]
- ].to(torch.int32)
+ page_table = inputs["req_to_token_indexs"][inputs["b_req_idx"].long(), : inputs["max_seq_len_in_batch"]].to(
+ torch.int32
+ )
q_seq_lens_t = inputs["b_seq_len"].to(torch.int32) - inputs["b_prompt_cache_len"].to(torch.int32)
@@ -366,9 +364,7 @@ def _run_case(
(3, 8, 2, 128, torch.bfloat16, 6, 8, 1024),
],
)
-def test_fa3_neo_prefill_with_image_tag(
- batch, Hq, Hk, D, dtype, seed, max_q_seq_len, max_prompt_cache_len
-):
+def test_fa3_neo_prefill_with_image_tag(batch, Hq, Hk, D, dtype, seed, max_q_seq_len, max_prompt_cache_len):
abs_err, rel_err, cos = _run_case(
batch=batch,
Hq=Hq,
@@ -426,9 +422,7 @@ def fa3_run():
o_triton = torch.empty_like(inputs["q"])
# Kernel signature requires position_ids but the current masking path does
# not read it; zeros are fine for perf measurement.
- position_ids_0 = torch.zeros(
- inputs["q"].shape[0], dtype=torch.int32, device=inputs["q"].device
- )
+ position_ids_0 = torch.zeros(inputs["q"].shape[0], dtype=torch.int32, device=inputs["q"].device)
def triton_run():
context_attention_fwd_neo(
diff --git a/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py b/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py
index ca6c5043c9..2bcced470b 100644
--- a/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py
+++ b/unit_tests/common/kv_trans_kernel/test_kv_trans_from_gpu.py
@@ -5,10 +5,7 @@
# =========================================================
# GPU guard
-pytestmark = pytest.mark.skipif(
- not torch.cuda.is_available(),
- reason="CUDA not available"
-)
+pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
# =========================================================
# Helper: build a GPU KV cache
def make_kv(L, T, H, D, device="cuda"):
@@ -20,14 +17,14 @@ def make_kv(L, T, H, D, device="cuda"):
# Test 1: basic copy (each token maps to one page slot)
def test_basic_copy():
L, T, H, D = 2, 8, 4, 16
- B = 1 # token_block_size
- P = 4 # all_page_num
+ B = 1 # token_block_size
+ P = 4 # all_page_num
gpu_kv = make_kv(L, T, H, D)
cpu_kv = torch.zeros((P, L, B, H, D), dtype=torch.float32, pin_memory=True)
token_indexes = torch.tensor([1, 3, 5, 7], device="cuda")
- page_indexes = torch.tensor([0, 1, 2, 3], device="cuda")
+ page_indexes = torch.tensor([0, 1, 2, 3], device="cuda")
offload_gpu_kv_to_cpu_all(
token_indexes,
@@ -43,10 +40,10 @@ def test_basic_copy():
for i in range(len(token_indexes)):
token = token_indexes[i].item()
- page = page_indexes[i].item()
+ page = page_indexes[i].item()
expected = gpu_kv[:, token, :, :].cpu() # (L,H,D)
- actual = cpu_kv[page, :, 0, :, :] # (L,H,D)
+ actual = cpu_kv[page, :, 0, :, :] # (L,H,D)
assert torch.allclose(expected, actual)
@@ -62,7 +59,7 @@ def test_random_tokens():
cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True)
token_indexes = torch.tensor([10, 2, 7, 15, 0, 3], device="cuda")
- page_indexes = torch.arange(6, device="cuda")
+ page_indexes = torch.arange(6, device="cuda")
offload_gpu_kv_to_cpu_all(
token_indexes,
@@ -80,10 +77,7 @@ def test_random_tokens():
t = token_indexes[i].item()
p = page_indexes[i].item()
- assert torch.allclose(
- cpu_kv[p, :, 0, :, :],
- gpu_kv[:, t, :, :].cpu()
- )
+ assert torch.allclose(cpu_kv[p, :, 0, :, :], gpu_kv[:, t, :, :].cpu())
# =========================================================
@@ -100,7 +94,7 @@ def test_with_scale():
cpu_scale = torch.zeros((P, L, B, H, D // 8), pin_memory=True)
token_indexes = torch.tensor([1, 2, 3], device="cuda")
- page_indexes = torch.tensor([0, 1, 2], device="cuda")
+ page_indexes = torch.tensor([0, 1, 2], device="cuda")
offload_gpu_kv_to_cpu_all(
token_indexes,
@@ -119,16 +113,10 @@ def test_with_scale():
p = page_indexes[i].item()
# KV
- assert torch.allclose(
- cpu_kv[p, :, 0, :, :],
- gpu_kv[:, t, :, :].cpu()
- )
+ assert torch.allclose(cpu_kv[p, :, 0, :, :], gpu_kv[:, t, :, :].cpu())
# scale
- assert torch.allclose(
- cpu_scale[p, :, 0, :],
- gpu_scale[:, t, :].cpu()
- )
+ assert torch.allclose(cpu_scale[p, :, 0, :], gpu_scale[:, t, :].cpu())
# =========================================================
@@ -144,11 +132,14 @@ def test_tp_split():
cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True)
token_indexes = torch.tensor([1, 2], device="cuda")
- page_indexes = torch.tensor([0, 1], device="cuda")
+ page_indexes = torch.tensor([0, 1], device="cuda")
offload_gpu_kv_to_cpu_all(
- token_indexes, gpu_kv, None,
- cpu_kv, None,
+ token_indexes,
+ gpu_kv,
+ None,
+ cpu_kv,
+ None,
page_indexes,
tp_index=0,
tp_world_size=tp_world_size,
@@ -156,8 +147,11 @@ def test_tp_split():
)
offload_gpu_kv_to_cpu_all(
- token_indexes, gpu_kv, None,
- cpu_kv, None,
+ token_indexes,
+ gpu_kv,
+ None,
+ cpu_kv,
+ None,
page_indexes,
tp_index=1,
tp_world_size=tp_world_size,
@@ -170,15 +164,9 @@ def test_tp_split():
t = token_indexes[i].item()
p = page_indexes[i].item()
- assert torch.allclose(
- cpu_kv[p, :, 0, :split, :],
- gpu_kv[:, t, :split, :].cpu()
- )
+ assert torch.allclose(cpu_kv[p, :, 0, :split, :], gpu_kv[:, t, :split, :].cpu())
- assert torch.allclose(
- cpu_kv[p, :, 0, split:, :],
- gpu_kv[:, t, split:, :].cpu()
- )
+ assert torch.allclose(cpu_kv[p, :, 0, split:, :], gpu_kv[:, t, split:, :].cpu())
# =========================================================
@@ -192,7 +180,7 @@ def test_empty():
cpu_kv = torch.zeros((P, L, B, H, D), pin_memory=True)
token_indexes = torch.tensor([], dtype=torch.long, device="cuda")
- page_indexes = torch.tensor([], dtype=torch.long, device="cuda")
+ page_indexes = torch.tensor([], dtype=torch.long, device="cuda")
offload_gpu_kv_to_cpu_all(
token_indexes,
@@ -208,5 +196,6 @@ def test_empty():
assert torch.all(cpu_kv == 0)
+
if __name__ == "__main__":
- pytest.main()
\ No newline at end of file
+ pytest.main()
From 2f9489c3816690d746a2e35ab2919fbe371d4ae2 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Wed, 6 May 2026 14:59:11 +0800
Subject: [PATCH 39/41] add seed rng session
---
lightllm/server/api_openai.py | 2 +
lightllm/server/core/objs/x2i_params.py | 6 ++
lightllm/server/httpserver/manager.py | 4 +
.../server/x2i_server/lightx2v/adapter.py | 26 ++++-
lightllm/server/x2i_server/manager.py | 29 +++++-
lightllm/server/x2i_server/rng_state_cache.py | 96 +++++++++++++++++++
6 files changed, 160 insertions(+), 3 deletions(-)
create mode 100644 lightllm/server/x2i_server/rng_state_cache.py
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index d137bba162..5f28827155 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -634,6 +634,7 @@ async def _get_text_generator_input(request: ChatCompletionRequest):
"n": request.n,
"best_of": request.n,
"add_special_tokens": False,
+ "seed": request.seed,
}
# Structured output handling
if request.response_format:
@@ -771,6 +772,7 @@ async def chat_completions_impl_v2(request: ChatCompletionRequestV2, raw_request
x2i_params = X2IParams()
x2i_params.init_from_image_config(request.image_config)
+ sampling_params.seed = x2i_params.seed
enable_thinking = request.chat_template_kwargs.get("enable_thinking", False)
print(f"x2i_params: {x2i_params} {image_only} {enable_thinking}", flush=True)
diff --git a/lightllm/server/core/objs/x2i_params.py b/lightllm/server/core/objs/x2i_params.py
index cac7b9abdd..9463ffc01f 100644
--- a/lightllm/server/core/objs/x2i_params.py
+++ b/lightllm/server/core/objs/x2i_params.py
@@ -45,6 +45,11 @@ class X2IParams(ctypes.Structure):
("past_kvcache_img", PastKVCachePageList),
("total_prompt_tokens", ctypes.c_int),
("request_id", ctypes.c_int64),
+ # session_id is used on the x2i server side to cache RNG state per chat session,
+ # so that concurrent sessions no longer clobber the global torch/cuda RNG for each other.
+ # The httpserver binds it to the request_id of the session's first image generation,
+ # and it then stays unchanged for the whole interleaved image-text exchange
+ # (it is not overwritten again when first_image=False).
+ ("session_id", ctypes.c_int64),
]
_width: int = 1024
@@ -80,6 +85,7 @@ def _get(key, default):
self.dynamic_resolution = _get("dynamic_resolution", X2IParams._dynamic_resolution)
self.total_prompt_tokens = 0
self.request_id = 0
+ self.session_id = 0
self.has_updated_hw = False
self.first_image = True
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 653117a071..9635965dd2 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -552,6 +552,10 @@ async def generation_wrapper(prompt, sample, multimodal, request):
# use the first request id as the gen image request id
x2i_req_id = generate_req_ids[0]
generation_params.request_id = x2i_req_id
+ # On the first generation, pin this request_id as the session_id;
+ # all later images of the same conversation reuse it, so the x2i side can cache RNG state per session.
+ if generation_params.first_image:
+ generation_params.session_id = x2i_req_id
req_status = X2IReqStatus(generation_params, generate_req_ids)
self.req_id_to_x2i_reqs[generation_params.request_id] = req_status
diff --git a/lightllm/server/x2i_server/lightx2v/adapter.py b/lightllm/server/x2i_server/lightx2v/adapter.py
index 95a795f617..60d04bb6d8 100644
--- a/lightllm/server/x2i_server/lightx2v/adapter.py
+++ b/lightllm/server/x2i_server/lightx2v/adapter.py
@@ -15,6 +15,7 @@
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.server.core.objs.x2i_params import X2IParams, X2IResponse, X2ICacheRelease, CfgNormType
from ..past_kv_cache_client import PastKVCacheClient
+from ..rng_state_cache import RngStateCache
logger = init_logger(__name__)
@@ -41,6 +42,11 @@ def __init__(self, args: StartArgs, rank: int, world_size: int):
self.task_dist_group = dist.new_group(backend="gloo", timeout=datetime.timedelta(days=30))
+ # Per-chat-session RNG snapshot, see lightllm/server/x2i_server/rng_state_cache.py.
+ # Each worker keeps its own copy; since _process in this process runs serially on a single stream
+ # and tasks are synchronized via broadcast, the RNG evolution stays consistent across ranks.
+ self.rng_state_cache = RngStateCache()
+
def _init_pipeline(self):
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = str(self.args.x2i_worker_nccl_port)
@@ -123,14 +129,30 @@ async def _process(self, param: X2IParams):
past_kv_cache_text,
past_kv_cache_img,
)
- seed = param.seed if param.first_image else None
- logger.info(f"seed: {seed} {param.seed} first_image: {param.first_image}")
+
+ session_id = param.session_id
+ if param.first_image:
+ # First image of this session: initialize the global RNG with the provided seed
+ seed = param.seed
+ else:
+ # Later images: restore the RNG state this session ended with last time, so seed_all calls from other sessions cannot pollute it
+ restored = self.rng_state_cache.restore(session_id)
+ seed = None
+ if not restored:
+ logger.warning(
+ f"session {session_id} rng state miss (maybe expired or first call after restart), "
+ f"fallback to current global rng"
+ )
+ logger.info(f"seed: {seed} param.seed: {param.seed} first_image: {param.first_image} session_id: {session_id}")
image = self.pipe.generate(
seed=seed,
save_result_path="",
target_shape=[param.height, param.width],
)
+ # Save the current RNG state for the next image of the same session
+ self.rng_state_cache.save(session_id)
+
return [image]
diff --git a/lightllm/server/x2i_server/manager.py b/lightllm/server/x2i_server/manager.py
index c48f283e79..0d169d3673 100644
--- a/lightllm/server/x2i_server/manager.py
+++ b/lightllm/server/x2i_server/manager.py
@@ -20,6 +20,7 @@
from lightllm.utils.dist_utils import set_current_device_id
from lightllm.utils.start_utils import start_submodule_processes
from .past_kv_cache_client import PastKVCacheClient
+from .rng_state_cache import RngStateCache
logger = init_logger(__name__)
@@ -62,6 +63,10 @@ def __init__(
self.past_kv_cache_client = PastKVCacheClient(only_create_meta_data=False, init_shm_data=True)
+ # Per-chat-session RNG snapshot, so concurrent sessions don't clobber each other's
+ # global torch / cuda RNG state between successive image generations.
+ self.rng_state_cache = RngStateCache()
+
async def wait_to_model_ready(self):
if self.world_size <= 1:
@@ -91,6 +96,22 @@ async def wait_to_model_ready(self):
args = [(self.args, rank, self.world_size) for rank in range(self.world_size)]
start_submodule_processes(funcs, args)
+ def _prepare_seed_for_iter(self, param: X2IParams, i: int, session_id: int):
+ """决定本轮 generate() 调用要传给 pipeline 的 seed。
+
+ - 一个 chat session 的第一张图(first_image=True 且 i==0):传种子,由 pipeline 内部 seed_all。
+ - 后续图(first_image=False 且 i==0):先把该 session 上次保存的 RNG state 恢复回来,
+ 再传 seed=None,让生成接着上一张图的 RNG 继续走。这样别的 session 中间穿插的 seed_all
+ 不会污染本 session 的 RNG。
+ - 同一次调用内部 i>0 的图:天然顺延全局 RNG 即可,不重置也不恢复。
+ """
+ if i == 0:
+ if param.first_image:
+ return param.seed
+ self.rng_state_cache.restore(session_id)
+ return None
+ return None
+
async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams):
if self.use_naive_x2i:
images = self.naive_x2i.t2i(past_kv_cache, past_kv_cache_text, param)
@@ -105,14 +126,18 @@ async def t2i_generate(self, past_kv_cache, past_kv_cache_text, param: X2IParams
timestep_shift=param.timestep_shift,
)
images = []
+ session_id = param.session_id
for i in range(param.num_images):
self.gen_pipe.runner.set_kvcache(past_kv_cache, past_kv_cache_text)
+ seed_arg = self._prepare_seed_for_iter(param, i, session_id)
image = self.gen_pipe.generate(
- seed=param.seed if param.first_image else None,
+ seed=seed_arg,
save_result_path="", # returns base64, no need to specify a path
target_shape=[param.height, param.width], # Height, Width
)
images.append(image)
+ # Save the current RNG state; the next image of the same session continues from here
+ self.rng_state_cache.save(session_id)
return images
async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_img, param: X2IParams):
@@ -131,6 +156,8 @@ async def it2i_generate(self, past_kv_cache, past_kv_cache_text, past_kv_cache_i
images = []
for i in range(param.num_images):
self.gen_pipe.runner.set_kvcache_i2i(past_kv_cache, past_kv_cache_text, past_kv_cache_img)
+ # Every it2i image carries an explicit seed (tied to img_len + i) and triggers seed_all each time,
+ # so it neither depends on nor pollutes the t2i per-session RNG state; no caching/restoring needed.
image = self.gen_pipe.generate(
seed=param.seed + param.past_kvcache_img.img_len + i,
save_result_path="", # returns base64, no need to specify a path
diff --git a/lightllm/server/x2i_server/rng_state_cache.py b/lightllm/server/x2i_server/rng_state_cache.py
new file mode 100644
index 0000000000..54a47f2e0d
--- /dev/null
+++ b/lightllm/server/x2i_server/rng_state_cache.py
@@ -0,0 +1,96 @@
+"""Per-session RNG state cache for the x2i server.
+
+In interleaved image-text scenarios, only the very first image of a chat
+session needs an explicit seed; subsequent images should continue from the
+RNG state left by the previous image generation. With multiple chat sessions
+running concurrently, the global torch / cuda / numpy / random RNG state is
+shared across sessions, so seeding for session B's first image would corrupt
+session A's continuation.
+
+This cache snapshots the RNG state after each generation per session_id and
+restores it before the next generation of the same session, so that each
+session sees a private, deterministic RNG stream.
+"""
+
+import random
+import time
+from typing import Dict, Tuple
+
+import numpy as np
+import torch
+
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+def capture_rng_state() -> Dict:
+ state = {
+ "random": random.getstate(),
+ "numpy": np.random.get_state(),
+ "torch": torch.get_rng_state(),
+ }
+ if torch.cuda.is_available():
+ state["torch_cuda"] = torch.cuda.get_rng_state_all()
+ return state
+
+
+def restore_rng_state(state: Dict) -> None:
+ random.setstate(state["random"])
+ np.random.set_state(state["numpy"])
+ torch.set_rng_state(state["torch"])
+ cuda_state = state.get("torch_cuda")
+ if cuda_state is not None and torch.cuda.is_available():
+ torch.cuda.set_rng_state_all(cuda_state)
+
+
+class RngStateCache:
+ """LRU + TTL cache for RNG snapshots keyed by chat session id.
+
+ - `save(session_id)`: snapshot current global RNG state for the session.
+ - `restore(session_id)`: restore previously saved state, returns True on hit.
+ - `discard(session_id)`: drop the session entry if any.
+ - Entries expire after `ttl_seconds` of inactivity, and the cache is hard
+ capped at `max_size` (oldest evicted first) to bound memory.
+ """
+
+ def __init__(self, max_size: int = 1024, ttl_seconds: float = 600.0):
+ self._max_size = max_size
+ self._ttl = ttl_seconds
+ # session_id -> (last_used_ts, state_dict)
+ self._cache: Dict[int, Tuple[float, Dict]] = {}
+
+ def save(self, session_id: int) -> None:
+ if session_id == 0:
+ return
+ self._cache[session_id] = (time.time(), capture_rng_state())
+ self._evict()
+
+ def restore(self, session_id: int) -> bool:
+ if session_id == 0:
+ return False
+ item = self._cache.get(session_id)
+ if item is None:
+ return False
+ ts, state = item
+ if self._ttl > 0 and time.time() - ts > self._ttl:
+ self._cache.pop(session_id, None)
+ logger.warning(f"RNG state for session {session_id} expired (TTL={self._ttl}s)")
+ return False
+ restore_rng_state(state)
+ self._cache[session_id] = (time.time(), state)
+ return True
+
+ def discard(self, session_id: int) -> None:
+ self._cache.pop(session_id, None)
+
+ def _evict(self) -> None:
+ if self._ttl > 0:
+ now = time.time()
+ expired = [sid for sid, (ts, _) in self._cache.items() if now - ts > self._ttl]
+ for sid in expired:
+ self._cache.pop(sid, None)
+ if len(self._cache) > self._max_size:
+ sorted_items = sorted(self._cache.items(), key=lambda kv: kv[1][0])
+ for sid, _ in sorted_items[: len(self._cache) - self._max_size]:
+ self._cache.pop(sid, None)
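
For reference, a minimal usage sketch of the per-session cache added above. The session_id value and the torch.randn calls are placeholders standing in for a real image generation; only RngStateCache and its save/restore methods come from this patch::

    import torch
    from lightllm.server.x2i_server.rng_state_cache import RngStateCache

    cache = RngStateCache(max_size=1024, ttl_seconds=600.0)
    session_id = 12345  # would be the session's first request_id in practice

    # First image of the session: seed explicitly, then snapshot the RNG state.
    torch.manual_seed(42)
    _ = torch.randn(4)          # stands in for the first image generation
    cache.save(session_id)

    # Some other session reseeds the global RNG in between.
    torch.manual_seed(7)

    # Next image of the original session: restore, then continue without a seed.
    assert cache.restore(session_id)
    _ = torch.randn(4)          # continues the original session's RNG stream
    cache.save(session_id)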
From 66b420c59e6d0eca12ef0c051cb40daa9f11bdc5 Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Fri, 8 May 2026 08:34:31 +0000
Subject: [PATCH 40/41] Fix Neo image-token attention scope in Triton prefill
---
.../layer_infer/transformer_layer_infer.py | 2 +-
lightllm/models/neo_chat_moe/infer_struct.py | 4 +-
.../layer_infer/transformer_layer_infer.py | 2 +-
.../context_attention_fwd_neo.py | 49 +++++++++---
.../triton_kernel/get_neo_position.py | 17 ++--
.../test_context_attention_fwd_neo.py | 79 +++++++++++--------
6 files changed, 99 insertions(+), 54 deletions(-)
diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
index c5c4f3c343..e9623b99a6 100644
--- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -113,7 +113,7 @@ def _context_attention_kernel(
infer_state.b_ready_cache_len,
infer_state.max_q_seq_len,
infer_state.req_manager.req_to_token_indexs,
- infer_state.b_image_token_tag,
+ infer_state.b_image_token_end,
)
return o_tensor
diff --git a/lightllm/models/neo_chat_moe/infer_struct.py b/lightllm/models/neo_chat_moe/infer_struct.py
index 1693bcb964..c52147cd9b 100644
--- a/lightllm/models/neo_chat_moe/infer_struct.py
+++ b/lightllm/models/neo_chat_moe/infer_struct.py
@@ -20,7 +20,7 @@ def __init__(self):
def init_some_extra_state(self, model: LlamaTpPartModel):
LlamaInferStateInfo.init_some_extra_state(self, model)
if self.is_prefill:
- self.b_image_token_tag = torch.zeros([self.position_ids.size(0)], dtype=torch.bool, device="cpu").cuda(
+ self.b_image_token_end = torch.zeros([self.position_ids.size(0)], dtype=torch.int32, device="cpu").cuda(
non_blocking=True
)
self.position_ids = self.get_neo_position(self.multimodal_params)
@@ -96,6 +96,6 @@ def get_neo_position(self, multimodal_params: List[dict]) -> torch.Tensor:
b_ready_cache_len=self.b_ready_cache_len,
b_q_seq_len=self.b_q_seq_len,
b_start_loc=self.b_q_start_loc,
- b_image_token_tag=self.b_image_token_tag,
+ b_image_token_end=self.b_image_token_end,
)
return position_ids
diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
index 66a320fe02..271c9592c3 100644
--- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
@@ -111,7 +111,7 @@ def _context_attention_kernel(
infer_state.b_ready_cache_len,
infer_state.max_q_seq_len,
infer_state.req_manager.req_to_token_indexs,
- infer_state.b_image_token_tag,
+ infer_state.b_image_token_end,
)
return o_tensor
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
index c9795e113a..8cefe486c0 100644
--- a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
+++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
@@ -34,7 +34,7 @@ def _fwd_kernel(
stride_req_to_tokens_s,
kv_group_num,
b_prompt_cache_len,
- b_image_token_tag,
+ b_image_token_end,
H: tl.constexpr,
QK_HEAD_DIM: tl.constexpr,
V_HEAD_DIM: tl.constexpr,
@@ -79,18 +79,25 @@ def _fwd_kernel(
acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)
# absolute q positions in the request
q_pos = prompt_cache_len + offs_m # [M]
- q_image_token_tag = tl.load(b_image_token_tag + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=False)
+ q_gid = tl.load(position_ids + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=-1)
+ q_image_end = tl.load(b_image_token_end + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=0)
- # per-M-block: only scan full K range if this block has image tokens
- has_image = tl.reduce_or(q_image_token_tag.to(tl.int32), axis=0) > 0
causal_end = tl.minimum(prompt_cache_len + block_start_loc + BLOCK_M, total_len)
- block_end_loc = tl.where(has_image, total_len, causal_end)
+ block_image_end = tl.minimum(tl.max(q_image_end, axis=0), total_len)
+ block_end_loc = tl.maximum(causal_end, block_image_end)
for start_n in range(0, block_end_loc, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
k_pos = start_n + offs_n # [N]
k_valid = k_pos < block_end_loc
+ k_in_new = k_pos >= prompt_cache_len
+ k_rel = k_pos - prompt_cache_len
+ k_gid = tl.load(
+ position_ids + cur_batch_in_all_start_index + k_rel,
+ mask=k_valid & k_in_new,
+ other=-2,
+ )
# map logical pos -> mem_index (for K/V)
kv_loc = tl.load(
@@ -105,8 +112,14 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
- # mask: causal OR same gid (only possible inside NEW part)
- mask = ((q_pos[:, None] >= k_pos[None, :]) | q_image_token_tag[:, None]) & k_valid[None, :]
+ causal_mask = q_pos[:, None] >= k_pos[None, :]
+ image_mask = (
+ (q_image_end[:, None] > 0)
+ & k_in_new[None, :]
+ & (k_pos[None, :] < q_image_end[:, None])
+ & (q_gid[:, None] == k_gid[None, :])
+ )
+ mask = (causal_mask | image_mask) & k_valid[None, :]
qk = tl.where(mask, qk * sm_scale, -1.0e8)
# online softmax
@@ -151,7 +164,7 @@ def context_attention_fwd_neo(
b_prompt_cache_len,
max_input_len,
req_to_token_indexs,
- b_image_token_tag,
+ b_image_token_end,
):
# minimal safety: position_ids must cover packed q rows
assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0])
@@ -199,7 +212,7 @@ def context_attention_fwd_neo(
req_to_token_indexs.stride(1),
kv_group_num=kv_group_num,
b_prompt_cache_len=b_prompt_cache_len,
- b_image_token_tag=b_image_token_tag,
+ b_image_token_end=b_image_token_end,
H=head,
QK_HEAD_DIM=Lk,
V_HEAD_DIM=Lk,
@@ -215,6 +228,7 @@ def reference_attention(
k,
v,
position_ids_q, # 1D packed like q (only NEW tokens)
+ b_image_token_end,
b_req_idx,
b_start_loc,
b_seq_len,
@@ -240,6 +254,7 @@ def reference_attention(
q_start = int(b_start_loc[b].item())
q_blk = q[q_start : q_start + new_len] # [M, Hq, D]
gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M]
+ image_end_new = b_image_token_end[q_start : q_start + new_len].to(torch.int64) # [M]
# gather K/V for full request by logical pos -> mem_index
token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) # [L]
@@ -267,7 +282,12 @@ def reference_attention(
k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new
k_gid[k_in_new] = gid_new[k_rel[k_in_new]]
- allow = allow | (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :])
+ allow = allow | (
+ (image_end_new[:, None] > 0)
+ & k_in_new[None, :]
+ & (k_pos[None, :] < image_end_new[:, None])
+ & (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :])
+ )
# scores: [Hq, M, L]
q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D]
@@ -337,17 +357,20 @@ def make_test_case(
# position_ids_q: only NEW tokens, packed like q
position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32)
+ b_image_token_end = torch.zeros((sum_q,), device=device, dtype=torch.int32)
for b in range(batch):
M = int(new_lens[b].item())
+ P = int(prompt_lens[b].item())
start = int(b_start_loc[b].item())
- gid = torch.arange(M, device=device, dtype=torch.int32)
+ gid = torch.arange(P, P + M, device=device, dtype=torch.int32)
# make one repeated block inside NEW part to simulate image tokens
if M >= 4 and torch.rand((), device=device).item() > 0.3:
s = int(torch.randint(0, M - 2, (1,), device=device).item())
e = min(M, s + 3)
gid[s:e] = gid[s]
+ b_image_token_end[start + s : start + e] = P + e
position_ids_q[start : start + M] = gid
@@ -362,6 +385,7 @@ def make_test_case(
v,
o,
position_ids_q,
+ b_image_token_end,
b_req_idx,
b_start_loc,
b_seq_len,
@@ -378,6 +402,7 @@ def check_once(device="cuda", dtype=torch.float16, seed=0):
v,
o,
position_ids_q,
+ b_image_token_end,
b_req_idx,
b_start_loc,
b_seq_len,
@@ -398,6 +423,7 @@ def check_once(device="cuda", dtype=torch.float16, seed=0):
b_prompt_cache_len,
max_new_len,
req_to_token_indexs,
+ b_image_token_end,
)
ref = reference_attention(
@@ -405,6 +431,7 @@ def check_once(device="cuda", dtype=torch.float16, seed=0):
k,
v,
position_ids_q,
+ b_image_token_end,
b_req_idx,
b_start_loc,
b_seq_len,
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
index dc57870bf1..b2e1af9f8f 100644
--- a/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
+++ b/lightllm/models/neo_chat_moe/triton_kernel/get_neo_position.py
@@ -16,7 +16,7 @@ def _get_neo_position_triton(
b_ready_cache_len: torch.Tensor,
b_q_seq_len: torch.Tensor,
b_start_loc: torch.Tensor,
- b_image_token_tag: torch.Tensor,
+ b_image_token_end: torch.Tensor,
BLOCK_SIZE: tl.constexpr,
) -> torch.Tensor:
cur_batch = tl.program_id(0)
@@ -30,6 +30,7 @@ def _get_neo_position_triton(
local_image_start_idx = tl.load(b_image_start_idx + image_start_num + i)
image_start_idx = start_loc + local_image_start_idx - cache_len
image_len = tl.load(b_image_len + image_start_num + i)
+ image_end = local_image_start_idx + image_len
# image_h = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 1)
image_w = tl.load(b_image_thwd + (image_start_num + i) * b_image_thwd_stride0 + 2)
@@ -40,8 +41,8 @@ def _get_neo_position_triton(
h_pos = off // image_w
w_pos = off % image_w
tl.store(
- b_image_token_tag + off + image_start_idx,
- True,
+ b_image_token_end + off + image_start_idx,
+ image_end,
mask=(off < image_len)
& (off + local_image_start_idx - cache_len < q_seq_len)
& (local_image_start_idx - cache_len + off >= 0),
@@ -97,7 +98,7 @@ def get_neo_position_triton(
b_ready_cache_len: torch.Tensor,
b_q_seq_len: torch.Tensor,
b_start_loc: torch.Tensor,
- b_image_token_tag: torch.Tensor,
+ b_image_token_end: torch.Tensor,
) -> torch.Tensor:
batch_size = b_q_seq_len.shape[0]
@@ -116,7 +117,7 @@ def get_neo_position_triton(
b_ready_cache_len=b_ready_cache_len,
b_q_seq_len=b_q_seq_len,
b_start_loc=b_start_loc,
- b_image_token_tag=b_image_token_tag,
+ b_image_token_end=b_image_token_end,
BLOCK_SIZE=BLOCK_SIZE,
)
@@ -133,7 +134,7 @@ def test():
.expand(3, -1)
.contiguous()
)
- b_image_token_tag = torch.zeros([position_ids.size(1)], dtype=torch.bool, device="cuda")
+ b_image_token_end = torch.zeros([position_ids.size(1)], dtype=torch.int32, device="cuda")
position_ids[1:].zero_()
b_ready_cache_len = torch.tensor([0, 0], dtype=torch.int32, device="cuda")
b_q_seq_len = torch.tensor([7, 13], dtype=torch.int32, device="cuda")
@@ -148,10 +149,10 @@ def test():
b_ready_cache_len,
b_q_seq_len,
b_start_loc,
- b_image_token_tag,
+ b_image_token_end,
)
- print(b_image_token_tag)
+ print(b_image_token_end)
print(position_ids)
# old_value = torch.cat([position_ids[:, 2:7], position_ids[:, 7 + 2 :]], dim=1)
diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py
index ec2d9d1760..730d4263ee 100644
--- a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py
+++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py
@@ -5,11 +5,11 @@
batch element we gather K/V for the whole request (prompt + new tokens) via
``req_to_token_indexs`` and apply::
- allow[m, k] = (k <= q_pos[m]) OR image_tag[m] for k in [0, total)
+ allow[m, k] = (k <= q_pos[m]) OR same_image_group[m, k]
-i.e. normal queries are causal, image-token queries can see every real key in
-the request. If the Triton kernel disagrees with this reference, the kernel is
-wrong.
+i.e. normal queries are causal, and image-token queries can only see future
+tokens from the same image span. If the Triton kernel disagrees with this
+reference, the kernel is wrong.
Run directly for quick debugging:
@@ -35,12 +35,13 @@ def torch_reference_context_attention_neo(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
+ position_ids_0: torch.Tensor,
b_req_idx: torch.Tensor,
b_start_loc: torch.Tensor,
b_seq_len: torch.Tensor,
b_prompt_cache_len: torch.Tensor,
req_to_token_indexs: torch.Tensor,
- b_image_token_tag: torch.Tensor,
+ b_image_token_end: torch.Tensor,
) -> torch.Tensor:
device = q.device
dtype = q.dtype
@@ -61,7 +62,8 @@ def torch_reference_context_attention_neo(
q_start = int(b_start_loc[b].item())
q_blk = q[q_start : q_start + new] # [M, Hq, D]
- image_tag = b_image_token_tag[q_start : q_start + new].to(torch.bool)
+ q_gid = position_ids_0[q_start : q_start + new].to(torch.int64)
+ q_image_end = b_image_token_end[q_start : q_start + new].to(torch.int64)
token_locs = req_to_token_indexs[req_idx, :total].to(torch.int64)
k_blk = k[token_locs] # [total, Hk, D]
@@ -70,7 +72,16 @@ def torch_reference_context_attention_neo(
q_pos = torch.arange(prompt, total, device=device, dtype=torch.int64) # [M]
k_pos = torch.arange(0, total, device=device, dtype=torch.int64) # [total]
causal = k_pos[None, :] <= q_pos[:, None]
- allow = causal | image_tag[:, None]
+
+ k_gid = torch.full((total,), -1, device=device, dtype=torch.int64)
+ k_gid[prompt:total] = q_gid
+ same_image = (
+ (q_image_end[:, None] > 0)
+ & (k_pos[None, :] >= prompt)
+ & (k_pos[None, :] < q_image_end[:, None])
+ & (q_gid[:, None] == k_gid[None, :])
+ )
+ allow = causal | same_image
out_blk = torch.empty_like(q_blk)
for h in range(Hq):
@@ -137,34 +148,39 @@ def _build_inputs(
req_to_token_indexs[req_id, :L] = pool[p : p + L].to(torch.int32)
p += L
- # Randomly place contiguous image-token spans inside each batch's NEW region.
- b_image_token_tag = torch.zeros(sum_new, dtype=torch.bool)
+ b_image_token_end = torch.zeros(sum_new, dtype=torch.int32)
+ position_ids_0 = torch.empty(sum_new, dtype=torch.int32)
for i in range(batch):
M = int(new_lens[i].item())
+ P = int(prompt_lens[i].item())
+ start_pack = int(b_start_loc[i].item())
+ position_ids_0[start_pack : start_pack + M] = torch.arange(P, P + M, dtype=torch.int32)
if M < 2:
continue
if torch.rand((), generator=g).item() > image_prob:
continue
n_spans = int(torch.randint(1, num_image_spans_max + 1, (1,), generator=g).item())
- start_pack = int(b_start_loc[i].item())
+ cursor = 0
for _ in range(n_spans):
- span_len = int(torch.randint(1, max(2, image_span_len_max) + 1, (1,), generator=g).item())
- span_len = min(span_len, M)
- s_rel = int(torch.randint(0, M - span_len + 1, (1,), generator=g).item())
- b_image_token_tag[start_pack + s_rel : start_pack + s_rel + span_len] = True
+ remaining = M - cursor
+ if remaining <= 0:
+ break
+ gap = int(torch.randint(0, remaining, (1,), generator=g).item())
+ s_rel = cursor + gap
+ max_span_len = min(image_span_len_max, M - s_rel)
+ if max_span_len <= 0:
+ break
+ span_len = int(torch.randint(1, max_span_len + 1, (1,), generator=g).item())
+ e_rel = s_rel + span_len
+ image_gid = P + s_rel
+ image_end = P + e_rel
+ position_ids_0[start_pack + s_rel : start_pack + e_rel] = image_gid
+ b_image_token_end[start_pack + s_rel : start_pack + e_rel] = image_end
+ cursor = e_rel
b_seq_len = total_lens.to(torch.int32)
b_prompt_cache_len = prompt_lens.to(torch.int32)
- # position_ids[0]: kernel API still requires it even though its current
- # mask logic only reads b_image_token_tag.
- position_ids_0 = torch.empty(sum_new, dtype=torch.int32)
- for i in range(batch):
- M = int(new_lens[i].item())
- P = int(prompt_lens[i].item())
- s = int(b_start_loc[i].item())
- position_ids_0[s : s + M] = torch.arange(P, P + M, dtype=torch.int32)
-
q = torch.randn((sum_new, Hq, D), dtype=dtype, device=device)
k = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device)
v = torch.randn((kv_pool_size, Hk, D), dtype=dtype, device=device)
@@ -182,13 +198,13 @@ def _build_inputs(
b_prompt_cache_len=b_prompt_cache_len.to(device),
max_new_len=max_new_len,
req_to_token_indexs=req_to_token_indexs.to(device),
- b_image_token_tag=b_image_token_tag.to(device),
+ b_image_token_end=b_image_token_end.to(device),
new_lens=new_lens,
prompt_lens=prompt_lens,
)
-def _report_per_batch_error(out_triton, out_ref, new_lens, b_start_loc, image_tag, tag=""):
+def _report_per_batch_error(out_triton, out_ref, new_lens, b_start_loc, image_token_end, tag=""):
print(f"\n[{tag}] per-batch error breakdown (abs / rel / cos):")
for i in range(new_lens.shape[0]):
s = int(b_start_loc[i].item())
@@ -201,7 +217,7 @@ def _report_per_batch_error(out_triton, out_ref, new_lens, b_start_loc, image_ta
denom = b.abs().max().item() + 1e-6
rel_err = abs_err / denom
cos = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
- n_img = int(image_tag[s : s + m].sum().item())
+ n_img = int((image_token_end[s : s + m] > 0).sum().item())
print(
f" batch {i:02d} | M={m:4d} | image_tokens={n_img:4d} | "
f"max_abs={abs_err:.4e} | max_rel={rel_err:.4e} | cos={cos:.6f}"
@@ -249,7 +265,7 @@ def _run_case(
inputs["b_prompt_cache_len"],
inputs["max_new_len"],
inputs["req_to_token_indexs"],
- inputs["b_image_token_tag"],
+ inputs["b_image_token_end"],
)
out_triton = inputs["o"]
@@ -257,12 +273,13 @@ def _run_case(
inputs["q"],
inputs["k"],
inputs["v"],
+ inputs["position_ids_0"],
inputs["b_req_idx"],
inputs["b_start_loc"],
inputs["b_seq_len"],
inputs["b_prompt_cache_len"],
inputs["req_to_token_indexs"],
- inputs["b_image_token_tag"],
+ inputs["b_image_token_end"],
)
a = out_triton.float()
@@ -272,8 +289,8 @@ def _run_case(
rel_err = abs_err / denom
cos = torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
- n_image = int(inputs["b_image_token_tag"].sum().item())
- n_tokens = int(inputs["b_image_token_tag"].numel())
+ n_image = int((inputs["b_image_token_end"] > 0).sum().item())
+ n_tokens = int(inputs["b_image_token_end"].numel())
if verbose:
print(
f"\ncase: batch={batch} Hq={Hq} Hk={Hk} D={D} dtype={dtype} "
@@ -289,7 +306,7 @@ def _run_case(
out_ref,
inputs["new_lens"],
inputs["b_start_loc"],
- inputs["b_image_token_tag"],
+ inputs["b_image_token_end"],
tag=f"seed={seed}",
)
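
A tiny, self-contained sketch of the mask rule this patch enforces. The sizes, group ids and the image span below are made up for illustration; the tensor names mirror the reference implementations in this patch::

    import torch

    prompt, new = 2, 6                              # cached prompt tokens / new tokens
    total = prompt + new
    q_pos = torch.arange(prompt, total)             # absolute positions of new (query) tokens
    k_pos = torch.arange(0, total)                  # absolute positions of all keys
    q_gid = torch.tensor([2, 3, 3, 3, 6, 7])        # image span at new positions 1..3 shares gid P + s_rel = 3
    q_image_end = torch.tensor([0, 6, 6, 6, 0, 0])  # absolute span end P + e_rel = 6; 0 marks text tokens

    k_gid = torch.full((total,), -1)
    k_gid[prompt:] = q_gid                          # only keys in the NEW region carry a group id
    causal = k_pos[None, :] <= q_pos[:, None]
    same_image = (
        (q_image_end[:, None] > 0)
        & (k_pos[None, :] >= prompt)
        & (k_pos[None, :] < q_image_end[:, None])
        & (q_gid[:, None] == k_gid[None, :])
    )
    allow = causal | same_image                     # image queries see their whole span, nothing beyond it
    print(allow.int())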
From c09bd8fd76ee73f7d66adabfca4ca37107e9aa0e Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Fri, 8 May 2026 09:13:33 +0000
Subject: [PATCH 41/41] delete position ids
---
.../layer_infer/transformer_layer_infer.py | 1 -
.../layer_infer/transformer_layer_infer.py | 1 -
.../context_attention_fwd_neo.py | 51 +------------------
.../test_context_attention_fwd_neo.py | 22 +-------
4 files changed, 4 insertions(+), 71 deletions(-)
diff --git a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
index e9623b99a6..cf88868a8a 100644
--- a/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat/layer_infer/transformer_layer_infer.py
@@ -106,7 +106,6 @@ def _context_attention_kernel(
kv[:, 0 : self.tp_k_head_num_, :],
kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
- infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
infer_state.b_req_idx,
infer_state.b_q_start_loc,
infer_state.b_seq_len,
diff --git a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
index 271c9592c3..9a791ff221 100644
--- a/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/neo_chat_moe/layer_infer/transformer_layer_infer.py
@@ -104,7 +104,6 @@ def _context_attention_kernel(
kv[:, 0 : self.tp_k_head_num_, :],
kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
- infer_state.position_ids[0], # [0,0,1,2,3,3,3,4]
infer_state.b_req_idx,
infer_state.b_q_start_loc,
infer_state.b_seq_len,
diff --git a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
index 8cefe486c0..1016e5a3c1 100644
--- a/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
+++ b/lightllm/models/neo_chat_moe/triton_kernel/context_attention_fwd_neo.py
@@ -13,7 +13,6 @@ def _fwd_kernel(
V,
sm_scale,
Out,
- position_ids, # 1D: packed like Q (only NEW tokens), length == Q.shape[0]
B_Start_Loc,
B_Seqlen,
Req_to_tokens,
@@ -79,7 +78,6 @@ def _fwd_kernel(
acc = tl.zeros([BLOCK_M, V_HEAD_DIM], dtype=tl.float32)
# absolute q positions in the request
q_pos = prompt_cache_len + offs_m # [M]
- q_gid = tl.load(position_ids + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=-1)
q_image_end = tl.load(b_image_token_end + cur_batch_in_all_start_index + offs_m, mask=q_valid, other=0)
causal_end = tl.minimum(prompt_cache_len + block_start_loc + BLOCK_M, total_len)
@@ -91,13 +89,6 @@ def _fwd_kernel(
k_pos = start_n + offs_n # [N]
k_valid = k_pos < block_end_loc
- k_in_new = k_pos >= prompt_cache_len
- k_rel = k_pos - prompt_cache_len
- k_gid = tl.load(
- position_ids + cur_batch_in_all_start_index + k_rel,
- mask=k_valid & k_in_new,
- other=-2,
- )
# map logical pos -> mem_index (for K/V)
kv_loc = tl.load(
@@ -113,12 +104,7 @@ def _fwd_kernel(
qk += tl.dot(q, k)
causal_mask = q_pos[:, None] >= k_pos[None, :]
- image_mask = (
- (q_image_end[:, None] > 0)
- & k_in_new[None, :]
- & (k_pos[None, :] < q_image_end[:, None])
- & (q_gid[:, None] == k_gid[None, :])
- )
+ image_mask = k_pos[None, :] < q_image_end[:, None]
mask = (causal_mask | image_mask) & k_valid[None, :]
qk = tl.where(mask, qk * sm_scale, -1.0e8)
@@ -157,7 +143,6 @@ def context_attention_fwd_neo(
k,
v,
o,
- position_ids, # 1D packed like q (only NEW tokens)
b_req_idx,
b_start_loc,
b_seq_len,
@@ -166,9 +151,6 @@ def context_attention_fwd_neo(
req_to_token_indexs,
b_image_token_end,
):
- # minimal safety: position_ids must cover packed q rows
- assert position_ids.numel() >= q.shape[0], (position_ids.numel(), q.shape[0])
-
BLOCK_M = 128 if not is_tesla() else 64
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
@@ -191,7 +173,6 @@ def context_attention_fwd_neo(
v,
sm_scale,
o,
- position_ids,
b_start_loc,
b_seq_len,
req_to_token_indexs,
@@ -227,7 +208,6 @@ def reference_attention(
q,
k,
v,
- position_ids_q, # 1D packed like q (only NEW tokens)
b_image_token_end,
b_req_idx,
b_start_loc,
@@ -253,7 +233,6 @@ def reference_attention(
q_start = int(b_start_loc[b].item())
q_blk = q[q_start : q_start + new_len] # [M, Hq, D]
- gid_new = position_ids_q[q_start : q_start + new_len].to(torch.int64) # [M]
image_end_new = b_image_token_end[q_start : q_start + new_len].to(torch.int64) # [M]
# gather K/V for full request by logical pos -> mem_index
@@ -272,22 +251,7 @@ def reference_attention(
# build allow mask:
# causal always
allow = k_pos[None, :] <= q_pos[:, None]
-
- # full-attn only inside NEW part by gid
- # compare only when k_pos in NEW
- k_in_new = k_pos >= prompt_len
- k_rel = (k_pos - prompt_len).clamp_min(0) # [L]
- # map k_rel to gid_new, but only valid where k_in_new
- k_gid = torch.empty((total_len,), device=device, dtype=torch.int64)
- k_gid[:] = 10 ** 12 + k_pos # never equal to gid_new
- k_gid[k_in_new] = gid_new[k_rel[k_in_new]]
-
- allow = allow | (
- (image_end_new[:, None] > 0)
- & k_in_new[None, :]
- & (k_pos[None, :] < image_end_new[:, None])
- & (gid_new[q_pos - prompt_len][:, None] == k_gid[None, :])
- )
+ allow = allow | (k_pos[None, :] < image_end_new[:, None])
# scores: [Hq, M, L]
q_t = q_blk.permute(1, 0, 2).to(torch.float32) # [Hq, M, D]
@@ -355,25 +319,18 @@ def make_test_case(
req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32)
p += L
- # position_ids_q: only NEW tokens, packed like q
- position_ids_q = torch.empty((sum_q,), device=device, dtype=torch.int32)
b_image_token_end = torch.zeros((sum_q,), device=device, dtype=torch.int32)
for b in range(batch):
M = int(new_lens[b].item())
P = int(prompt_lens[b].item())
start = int(b_start_loc[b].item())
- gid = torch.arange(P, P + M, device=device, dtype=torch.int32)
-
# make one repeated block inside NEW part to simulate image tokens
if M >= 4 and torch.rand((), device=device).item() > 0.3:
s = int(torch.randint(0, M - 2, (1,), device=device).item())
e = min(M, s + 3)
- gid[s:e] = gid[s]
b_image_token_end[start + s : start + e] = P + e
- position_ids_q[start : start + M] = gid
-
q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype)
k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype)
v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype)
@@ -384,7 +341,6 @@ def make_test_case(
k,
v,
o,
- position_ids_q,
b_image_token_end,
b_req_idx,
b_start_loc,
@@ -401,7 +357,6 @@ def check_once(device="cuda", dtype=torch.float16, seed=0):
k,
v,
o,
- position_ids_q,
b_image_token_end,
b_req_idx,
b_start_loc,
@@ -416,7 +371,6 @@ def check_once(device="cuda", dtype=torch.float16, seed=0):
k,
v,
o,
- position_ids_q,
b_req_idx,
b_start_loc,
b_seq_len,
@@ -430,7 +384,6 @@ def check_once(device="cuda", dtype=torch.float16, seed=0):
q,
k,
v,
- position_ids_q,
b_image_token_end,
b_req_idx,
b_start_loc,
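
The same toy request as the sketch after the previous patch, now with the simplified rule used here. Because an image span is contiguous and every query inside it carries that span's absolute end, any key allowed by ``k_pos < image_end`` that lies outside the span is already covered by the causal term, so the gid comparison is redundant and the group ids can be dropped entirely::

    import torch

    prompt, new = 2, 6
    q_pos = torch.arange(prompt, prompt + new)
    k_pos = torch.arange(0, prompt + new)
    q_image_end = torch.tensor([0, 6, 6, 6, 0, 0])  # same toy image span, group ids no longer needed

    allow = (k_pos[None, :] <= q_pos[:, None]) | (k_pos[None, :] < q_image_end[:, None])
    print(allow.int())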
diff --git a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py
index 730d4263ee..5cb8bd58ff 100644
--- a/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py
+++ b/unit_tests/common/basemodel/triton_kernel/att/prefill_att/test_context_attention_fwd_neo.py
@@ -5,7 +5,7 @@
batch element we gather K/V for the whole request (prompt + new tokens) via
``req_to_token_indexs`` and apply::
- allow[m, k] = (k <= q_pos[m]) OR same_image_group[m, k]
+ allow[m, k] = (k <= q_pos[m]) OR (k < image_end[m])
i.e. normal queries are causal, and image-token queries can only see future
tokens from the same image span. If the Triton kernel disagrees with this
@@ -35,7 +35,6 @@ def torch_reference_context_attention_neo(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
- position_ids_0: torch.Tensor,
b_req_idx: torch.Tensor,
b_start_loc: torch.Tensor,
b_seq_len: torch.Tensor,
@@ -62,7 +61,6 @@ def torch_reference_context_attention_neo(
q_start = int(b_start_loc[b].item())
q_blk = q[q_start : q_start + new] # [M, Hq, D]
- q_gid = position_ids_0[q_start : q_start + new].to(torch.int64)
q_image_end = b_image_token_end[q_start : q_start + new].to(torch.int64)
token_locs = req_to_token_indexs[req_idx, :total].to(torch.int64)
@@ -72,16 +70,7 @@ def torch_reference_context_attention_neo(
q_pos = torch.arange(prompt, total, device=device, dtype=torch.int64) # [M]
k_pos = torch.arange(0, total, device=device, dtype=torch.int64) # [total]
causal = k_pos[None, :] <= q_pos[:, None]
-
- k_gid = torch.full((total,), -1, device=device, dtype=torch.int64)
- k_gid[prompt:total] = q_gid
- same_image = (
- (q_image_end[:, None] > 0)
- & (k_pos[None, :] >= prompt)
- & (k_pos[None, :] < q_image_end[:, None])
- & (q_gid[:, None] == k_gid[None, :])
- )
- allow = causal | same_image
+ allow = causal | (k_pos[None, :] < q_image_end[:, None])
out_blk = torch.empty_like(q_blk)
for h in range(Hq):
@@ -149,12 +138,10 @@ def _build_inputs(
p += L
b_image_token_end = torch.zeros(sum_new, dtype=torch.int32)
- position_ids_0 = torch.empty(sum_new, dtype=torch.int32)
for i in range(batch):
M = int(new_lens[i].item())
P = int(prompt_lens[i].item())
start_pack = int(b_start_loc[i].item())
- position_ids_0[start_pack : start_pack + M] = torch.arange(P, P + M, dtype=torch.int32)
if M < 2:
continue
if torch.rand((), generator=g).item() > image_prob:
@@ -172,9 +159,7 @@ def _build_inputs(
break
span_len = int(torch.randint(1, max_span_len + 1, (1,), generator=g).item())
e_rel = s_rel + span_len
- image_gid = P + s_rel
image_end = P + e_rel
- position_ids_0[start_pack + s_rel : start_pack + e_rel] = image_gid
b_image_token_end[start_pack + s_rel : start_pack + e_rel] = image_end
cursor = e_rel
@@ -191,7 +176,6 @@ def _build_inputs(
k=k,
v=v,
o=o,
- position_ids_0=position_ids_0.to(device),
b_req_idx=b_req_idx.to(device),
b_start_loc=b_start_loc.to(device),
b_seq_len=b_seq_len.to(device),
@@ -258,7 +242,6 @@ def _run_case(
inputs["k"],
inputs["v"],
inputs["o"],
- inputs["position_ids_0"],
inputs["b_req_idx"],
inputs["b_start_loc"],
inputs["b_seq_len"],
@@ -273,7 +256,6 @@ def _run_case(
inputs["q"],
inputs["k"],
inputs["v"],
- inputs["position_ids_0"],
inputs["b_req_idx"],
inputs["b_start_loc"],
inputs["b_seq_len"],