Skip to content

Commit 9ea753a

Browse files
JimHsiung authored and liutongxuan committed
feat: support glm4v_moe for npu.
1 parent 0c38d0d commit 9ea753a

File tree

10 files changed

+1567
-185
lines changed

10 files changed

+1567
-185
lines changed

xllm/core/layers/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ cc_library(
5656
HDRS
5757
column_parallel_linear.h
5858
deepseek_v2_decoder_layer.h
59+
glm4_vision_encode_layer.h
5960
llama_decoder_layer.h
6061
multi_head_attention.h
6162
qwen2_decoder_layer.h
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
18+
#if defined(USE_NPU)
19+
#include "npu/npu_glm4_vision_encoder_layer_impl.h"
20+
#endif
21+
22+
namespace xllm {
23+
namespace layer {
24+
25+
#if defined(USE_NPU)
26+
class Glm4VisionEncoderLayer
27+
: public torch::nn::ModuleHolder<NpuGlm4VisionEncoderLayerImpl> {
28+
public:
29+
using torch::nn::ModuleHolder<NpuGlm4VisionEncoderLayerImpl>::ModuleHolder;
30+
using Impl __attribute__((__unused__)) = NpuGlm4VisionEncoderLayerImpl;
31+
32+
Glm4VisionEncoderLayer(const ModelContext& context)
33+
: ModuleHolder(std::make_shared<NpuGlm4VisionEncoderLayerImpl>(context)) {
34+
}
35+
};
36+
#endif
37+
38+
} // namespace layer
39+
} // namespace xllm

xllm/core/layers/npu/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ cc_library(
1919
npu_base_layer.h
2020
npu_column_parallel_linear_impl.h
2121
npu_glm4_moe_decoder_layer.h
22+
npu_glm4_vision_encoder_layer_impl.h
2223
npu_deepseek_v2_decoder_layer_impl.h
2324
npu_llama_decoder_layer_impl.h
2425
npu_qwen2_decoder_layer_impl.h
@@ -39,6 +40,7 @@ cc_library(
3940
npu_base_layer.cpp
4041
npu_column_parallel_linear_impl.cpp
4142
npu_glm4_moe_decoder_layer.cpp
43+
npu_glm4_vision_encoder_layer_impl.cpp
4244
npu_deepseek_v2_decoder_layer_impl.cpp
4345
npu_llama_decoder_layer_impl.cpp
4446
npu_qwen2_decoder_layer_impl.cpp
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
// Copied from the Qwen3-VL vision encoder layer; keep this file in sync with
// any modifications made there.
17+
#include "npu_glm4_vision_encoder_layer_impl.h"
18+
19+
#include <glog/logging.h>
20+
#include <mstx/ms_tools_ext.h>
21+
22+
#include <iostream>
23+
#include <map>
24+
25+
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
26+
#include "torch_npu/csrc/core/npu/NPUException.h"
27+
#include "xllm_kernels/models/glm4v/glm4v_encoder.h"
28+
29+
namespace xllm {
30+
namespace layer {
31+
32+
// Slot indices into the per-layer weight arrays (at_weight_tensors_ /
// atb_weight_tensors_).
//
// Slots 0-5 are the ones copied into the operation's variant pack by
// build_node_variant_pack(). Slots 6 and 7 only stage the separately loaded
// up/gate projections until merge_loaded_weights() concatenates them into
// IN_LINEAR_GATE_UP_WEIGHT.
enum Glm4VisionEncoderLayerTensorId : int {
  IN_INPUT_NORM_WEIGHT = 0,      // "norm1.weight"
  IN_POST_NORM_WEIGHT = 1,       // "norm2.weight"
  IN_QKV_WEIGHT = 2,             // "attn.qkv.weight"
  IN_ATTN_PROJ_WEIGHT = 3,       // "attn.proj.weight"
  IN_LINEAR_GATE_UP_WEIGHT = 4,  // built by concatenating gate and up
  IN_LINEAR_DOWN_WEIGHT = 5,     // "mlp.down_proj.weight"
  IN_LINEAR_UP_WEIGHT = 6,       // "mlp.up_proj.weight" (staging slot)
  IN_LINEAR_GATE_WEIGHT = 7      // "mlp.gate_proj.weight" (staging slot)
};
42+
43+
// Number of weight slots kept per layer; equals the number of enumerators in
// Glm4VisionEncoderLayerTensorId (including the two staging slots).
constexpr uint64_t WEIGHT_COUNT_PER_LAYER = 8;
44+
45+
static std::vector<std::pair<int, std::string>> WEIGHT_MAPPING = {
46+
{IN_INPUT_NORM_WEIGHT, "norm1.weight"},
47+
{IN_POST_NORM_WEIGHT, "norm2.weight"},
48+
{IN_QKV_WEIGHT, "attn.qkv.weight"},
49+
{IN_ATTN_PROJ_WEIGHT, "attn.proj.weight"},
50+
{IN_LINEAR_GATE_WEIGHT, "mlp.gate_proj.weight"},
51+
{IN_LINEAR_UP_WEIGHT, "mlp.up_proj.weight"},
52+
{IN_LINEAR_DOWN_WEIGHT, "mlp.down_proj.weight"}};
53+
54+
// {weight,dim}
55+
// IN_QKV_WEIGHT SHARD artificially in merge_loaded_weights
56+
static std::map<int, int> WEIGHT_SHARD = {{IN_ATTN_PROJ_WEIGHT, 1},
57+
{IN_LINEAR_UP_WEIGHT, 0},
58+
{IN_LINEAR_GATE_WEIGHT, 0},
59+
{IN_LINEAR_DOWN_WEIGHT, 1}};
60+
// TODO: check shape with atb log -- HW pxy
61+
62+
// Fills the vision encoder layer parameters from the model configuration and
// the tensor-parallel setup.
//
// @param param          Output: parameter struct handed to the ATB operation.
// @param args           Model arguments (dtype, vision hidden size, vision
//                       head count, RMS-norm epsilon).
// @param parallel_args  Supplies this process's rank and the world size.
//
// Aborts (glog CHECK) when the world size or head count is not positive, or
// when the head count cannot be split evenly across ranks; previously these
// cases produced a division by zero or a silently truncated per-rank count.
void NpuGlm4VisionEncoderLayerImpl::param_from_args(
    atb_speed::glm::VisionEncoderLayerParam& param,
    const ModelArgs& args,
    const ParallelArgs& parallel_args) {
  param.isBF16 = args.dtype() == "bfloat16";
  param.supportLcoc = false;
  param.supportLora = false;
  param.loraEnableGMM = false;
  param.enableLogN = false;
  param.backend = "hccl";
  param.rank = parallel_args.rank();
  param.worldSize = parallel_args.world_size();

  param.quantType = 0;
  param.quantGroupSize = 64;

  const auto num_heads = args.mm_num_attention_heads();
  CHECK(param.worldSize > 0)
      << "world size must be positive, got " << param.worldSize;
  CHECK(num_heads > 0)
      << "mm_num_attention_heads must be positive, got " << num_heads;
  CHECK(num_heads % param.worldSize == 0)
      << "mm_num_attention_heads (" << num_heads
      << ") must be divisible by world size (" << param.worldSize << ")";

  param.numAttentionHeadsPerRank = num_heads / param.worldSize;
  param.hiddenSizePerAttentionHead = args.mm_hidden_size() / num_heads;
  // The key/value head count is taken from the same attention head count.
  param.numKeyValueHeadsPerRank =
      static_cast<int>(num_heads) / param.worldSize;

  param.rmsNormEps = args.rms_norm_eps();
}
88+
89+
// Constructs the layer: derives the operation parameters from `context`,
// records dtype and device index, and fills every weight slot with a
// one-element zero tensor that verify_loaded_weights() later recognises as
// "not loaded".
NpuGlm4VisionEncoderLayerImpl::NpuGlm4VisionEncoderLayerImpl(
    const ModelContext& context)
    : NpuBaseLayer(context) {
  const auto& model_args = context.get_model_args();
  const auto& parallel_args = context.get_parallel_args();
  const auto& options = context.get_tensor_options();
  param_from_args(encode_param_, model_args, parallel_args);

  dtype_ = c10::typeMetaToScalarType(options.dtype());
  device_id_ = options.device().index();

  // Build the ATB placeholder from the retained at_placeholder_ rather than
  // from a temporary tensor. The temporary was destroyed at the end of the
  // statement, which would leave placeholder_ referring to released device
  // memory (assuming AtTensor2Tensor produces a non-owning descriptor, as its
  // use alongside at_weight_tensors_ suggests).
  at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_);
  placeholder_ = atb_speed::Utils::AtTensor2Tensor(at_placeholder_);

  at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER);
  atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER);
  for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
    at_weight_tensors_[i] = torch::zeros({1}).to(options);
  }
}
108+
109+
void NpuGlm4VisionEncoderLayerImpl::verify_loaded_weights() const {
110+
for (const auto& [index, name] : WEIGHT_MAPPING) {
111+
CHECK(at_weight_tensors_[index].sizes() != std::vector<int64_t>({1}))
112+
<< "weight is not loaded for " << name;
113+
}
114+
}
115+
116+
void NpuGlm4VisionEncoderLayerImpl::merge_loaded_weights() {
117+
if (encode_param_.worldSize > 1) {
118+
// spilt pack qkv weight when enable tp
119+
get_weights_col_packed_qkv();
120+
}
121+
122+
at_weight_tensors_[IN_LINEAR_GATE_UP_WEIGHT] =
123+
torch::cat({at_weight_tensors_[IN_LINEAR_GATE_WEIGHT],
124+
at_weight_tensors_[IN_LINEAR_UP_WEIGHT]},
125+
0);
126+
at_weight_tensors_[IN_LINEAR_GATE_WEIGHT] = torch::empty({}, device_);
127+
at_weight_tensors_[IN_LINEAR_UP_WEIGHT] = torch::empty({}, device_);
128+
129+
c10_npu::NPUCachingAllocator::emptyCache();
130+
for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
131+
atb_weight_tensors_[i] =
132+
atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]);
133+
}
134+
135+
init_layer();
136+
}
137+
138+
// tp spilt weight
139+
void NpuGlm4VisionEncoderLayerImpl::get_weights_col_packed_qkv() {
140+
int rank = encode_param_.rank;
141+
int worldSize = encode_param_.worldSize;
142+
// split qkv weight
143+
auto qkv_weight = torch::chunk(at_weight_tensors_[IN_QKV_WEIGHT], 3, 0);
144+
// get local weight and merge
145+
auto new_qkv_weight = torch::cat({(qkv_weight[0].chunk(worldSize, 0))[rank],
146+
(qkv_weight[1].chunk(worldSize, 0))[rank],
147+
(qkv_weight[2].chunk(worldSize, 0))[rank]},
148+
0);
149+
at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight;
150+
}
151+
152+
void NpuGlm4VisionEncoderLayerImpl::load_state_dict(
153+
const StateDict& state_dict) {
154+
for (const auto& [index, name] : WEIGHT_MAPPING) {
155+
if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) {
156+
set_weight(state_dict, name, index, WEIGHT_SHARD[index]);
157+
} else {
158+
set_weight(state_dict, name, index);
159+
}
160+
}
161+
}
162+
163+
int64_t NpuGlm4VisionEncoderLayerImpl::init_layer() {
164+
name_ = "glm4_vision_encoder_layer";
165+
model_name_ = "glm4v";
166+
CHECK_OPERATION_STATUS_RETURN(init_node(encode_node_, encode_param_));
167+
return atb::NO_ERROR;
168+
}
169+
170+
// Creates the GLM4V encoder-layer operation for `node` and sizes its tensor
// lists.
//
// @param node   Node that receives the operation and the tensor bookkeeping.
// @param param  Parameters used to build the operation.
// @return atb::NO_ERROR on success, -1 if the operation could not be created
//         or reports no inputs.
//
// The first WEIGHT_COUNT_PER_LAYER input slots are pointed at this layer's
// ATB weight tensors; the variant pack is sized to match and has exactly one
// output.
int64_t NpuGlm4VisionEncoderLayerImpl::init_node(
    atb_speed::Model::Node& node,
    atb_speed::glm::VisionEncoderLayerParam& param) {
  atb::Operation* operation = nullptr;
  atb_speed::glm::Glm4v_EncoderLayer(param, &operation);
  node.operation.reset(operation);
  if (node.operation == nullptr) {
    LOG(ERROR) << "node.operation is null";
    return -1;
  }
  if (node.operation->GetInputNum() < 1) {
    LOG(ERROR) << "Can not resize number which is smaller than 1";
    return -1;
  }
  node.inTensors.resize(node.operation->GetInputNum());
  node.outTensors.resize(1);

  for (size_t weightTensorId = 0; weightTensorId < WEIGHT_COUNT_PER_LAYER;
       ++weightTensorId) {
    node.inTensors.at(weightTensorId) = &atb_weight_tensors_[weightTensorId];
  }

  // resize() alone allocates the required storage; no separate reserve().
  node.variantPack.inTensors.resize(node.inTensors.size());
  node.variantPack.outTensors.resize(1);
  return atb::NO_ERROR;
}
199+
200+
// Runs the encoder layer on `x`.
//
// @param x              Input tensor; also bound as the operation's output,
//                       and returned to the caller.
// @param cos_pos        Cosine term passed to the operation (presumably the
//                       rotary position embedding -- confirm).
// @param sin_pos        Sine counterpart of `cos_pos`.
// @param cu_seqlen      Sequence-length tensor passed to the operation.
// @param cu_seqlen_vec  Host-side values attached to `cu_seqlen` as host data.
// @param input_params   Forwarded to build_node_variant_pack().
// @param node_id        Forwarded to execute_node().
// @param event          Not used by this implementation.
// @param event_flag     Not used by this implementation.
// @return `x`.
//
// A non-zero execution status is fatal (LOG(FATAL)).
torch::Tensor NpuGlm4VisionEncoderLayerImpl::forward(
    torch::Tensor& x,
    torch::Tensor& cos_pos,
    torch::Tensor& sin_pos,
    torch::Tensor& cu_seqlen,
    std::vector<int>& cu_seqlen_vec,
    ModelInputParams& input_params,
    int node_id,
    aclrtEvent* event,
    std::atomic<bool>* event_flag) {
  build_node_variant_pack(encode_node_,
                          x,
                          cos_pos,
                          sin_pos,
                          cu_seqlen,
                          cu_seqlen_vec,
                          input_params,
                          true);

  const atb::Status st = execute_node(encode_node_, node_id);
  LOG_IF(FATAL, st != 0) << model_name_
                         << " execute encode layer fail, error code: " << st;
  return x;
}
226+
227+
// Binds weights and per-call tensors into `node`'s variant pack.
//
// Input layout, with W = WEIGHT_COUNT_PER_LAYER - 2 (the two gate/up staging
// slots are not bound):
//   [0, W)  layer weights, in Glm4VisionEncoderLayerTensorId order
//   W       x
//   W + 1   cos_pos
//   W + 2   sin_pos
//   W + 3   cu_seqlen, with hostData pointing at cu_seqlen_vec's storage
// The single output is bound to the same tensor as `x`.
//
// `cu_seqlen_vec` must stay alive and unmodified until the node has executed,
// because only a raw pointer to its data is stored. `input_params` and
// `is_prefill` are not used by this implementation.
void NpuGlm4VisionEncoderLayerImpl::build_node_variant_pack(
    atb_speed::Model::Node& node,
    torch::Tensor& x,
    torch::Tensor& cos_pos,
    torch::Tensor& sin_pos,
    torch::Tensor& cu_seqlen,
    std::vector<int>& cu_seqlen_vec,
    ModelInputParams& input_params,
    bool is_prefill) {
  internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x);

  const auto actual_weight_num = WEIGHT_COUNT_PER_LAYER - 2;
  for (size_t i = 0; i < actual_weight_num; ++i) {
    // CHECK_THROW presumably raises when its condition is true (confirm
    // against the macro definition), so this rejects unset weight pointers.
    CHECK_THROW(node.inTensors.at(i) == nullptr,
                model_name_ << " inTensor " << i << " is NULL");
    node.variantPack.inTensors.at(i) = *node.inTensors.at(i);
  }

  node.variantPack.inTensors.at(actual_weight_num) = internal_tensors_;
  node.variantPack.inTensors.at(actual_weight_num + 1) =
      atb_speed::Utils::AtTensor2Tensor(cos_pos);
  node.variantPack.inTensors.at(actual_weight_num + 2) =
      atb_speed::Utils::AtTensor2Tensor(sin_pos);
  node.variantPack.inTensors.at(actual_weight_num + 3) =
      atb_speed::Utils::AtTensor2Tensor(cu_seqlen);
  node.variantPack.inTensors.at(actual_weight_num + 3).hostData =
      cu_seqlen_vec.data();

  // Output shares the input descriptor built from `x`.
  node.variantPack.outTensors.at(0) = internal_tensors_;
}
259+
260+
} // namespace layer
261+
} // namespace xllm

0 commit comments

Comments
 (0)