diff --git a/xllm/core/layers/npu/CMakeLists.txt b/xllm/core/layers/npu/CMakeLists.txt index a82e536be..52b215548 100644 --- a/xllm/core/layers/npu/CMakeLists.txt +++ b/xllm/core/layers/npu/CMakeLists.txt @@ -29,6 +29,20 @@ cc_library( npu_siglip_encoder_layer_impl.h ../common/rotary_embedding_util.h rotary_embedding.h + loader/qwen3_decoder_loader.h + loader/qwen2_decoder_loader.h + loader/qwen3_moe_decoder_loader.h + loader/word_embedding_loader.h + loader/lm_head_loader.h + loader/column_parallel_linear_loader.h + loader/deepseek_v2_decoder_loader.h + loader/glm4_moe_decoder_loader.h + loader/llama_decoder_loader.h + loader/qwen2dot5_vision_encoder_loader.h + loader/qwen3_vision_encoder_loader.h + loader/rms_norm_loader.h + loader/siglip_encoder_loader.h + loader/base_loader.h SRCS npu_word_embedding_impl.cpp npu_pos_embedding_impl.cpp @@ -53,6 +67,20 @@ cc_library( npu_siglip_encoder_layer_impl.cpp ../common/rotary_embedding_util.cpp rotary_embedding.cpp + loader/qwen3_decoder_loader.cpp + loader/qwen2_decoder_loader.cpp + loader/qwen3_moe_decoder_loader.cpp + loader/word_embedding_loader.cpp + loader/lm_head_loader.cpp + loader/column_parallel_linear_loader.cpp + loader/deepseek_v2_decoder_loader.cpp + loader/glm4_moe_decoder_loader.cpp + loader/llama_decoder_loader.cpp + loader/qwen2dot5_vision_encoder_loader.cpp + loader/qwen3_vision_encoder_loader.cpp + loader/rms_norm_loader.cpp + loader/siglip_encoder_loader.cpp + loader/base_loader.cpp DEPS "-Wl,--whole-archive" "-Wl,--no-whole-archive" diff --git a/xllm/core/layers/npu/loader/base_loader.cpp b/xllm/core/layers/npu/loader/base_loader.cpp new file mode 100644 index 000000000..5d62fbbd5 --- /dev/null +++ b/xllm/core/layers/npu/loader/base_loader.cpp @@ -0,0 +1,146 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "base_loader.h" + +namespace xllm { +namespace layer { + +BaseLoader::BaseLoader(uint64_t weight_count, const ModelContext& context) + : weight_count_(weight_count), + parallel_args_(context.get_parallel_args()), + device_(context.get_tensor_options().device()) { + auto quant_args = context.get_quant_args(); + if (!quant_args.quantize_type().empty()) { + quantize_type_ = quant_args.quantize_type(); + } + + if (!quant_args.torch_dtype().empty()) { + torch_dtype_ = quant_args.torch_dtype(); + } + + dp_size_ = parallel_args_.dp_size(); + dp_local_tp_size_ = parallel_args_.world_size() / dp_size_; + dp_rank_ = parallel_args_.rank() / dp_local_tp_size_; + CHECK_EQ(parallel_args_.world_size(), dp_size_ * dp_local_tp_size_); + dp_local_tp_rank_ = parallel_args_.rank() % dp_local_tp_size_; + + at_weight_tensors_.resize(weight_count_); +} + +void BaseLoader::set_weight(const StateDict& state_dict, + const std::string& tensor_name, + int weight_position) { + for (const auto& [name, tensor] : state_dict) { + if (absl::EndsWith(name, tensor_name)) { + at::Tensor mutable_tensor = tensor; + correct_tensor_dtype(mutable_tensor, tensor_name); + at_weight_tensors_[weight_position] = mutable_tensor.to(device_); + } + } +} + +void BaseLoader::set_weight(const StateDict& state_dict, + const std::string& tensor_name, + int weight_position, + int dim) { + for (const auto& [name, tensor] : state_dict) { + if (absl::EndsWith(name, tensor_name)) { + if (parallel_args_.world_size() <= 1) { + at::Tensor mutable_tensor = tensor; + correct_tensor_dtype(mutable_tensor, tensor_name); + at_weight_tensors_[weight_position] = mutable_tensor.to(device_); + } else { + at_weight_tensors_[weight_position] = + state_dict + .get_sharded_tensor(tensor_name, + /*dim=*/dim, + /*rank=*/parallel_args_.rank(), + /*world_size=*/parallel_args_.world_size()) + .to(device_); + correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name); + } + } + } +} + +void BaseLoader::set_weight(const StateDict& state_dict, + const std::string& tensor_name, + int weight_position, + int dim, + int rank, + int world_size) { + for (const auto& [name, tensor] : state_dict) { + if (absl::EndsWith(name, tensor_name)) { + if (world_size <= 1) { + at::Tensor mutable_tensor = tensor; + correct_tensor_dtype(mutable_tensor, tensor_name); + at_weight_tensors_[weight_position] = mutable_tensor.to(device_); + } else { + at_weight_tensors_[weight_position] = + state_dict + .get_sharded_tensor(tensor_name, + /*dim=*/dim, + /*rank=*/rank, + /*world_size=*/world_size) + .to(device_); + correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name); + } + } + } +} + +void BaseLoader::correct_tensor_dtype(torch::Tensor& tensor, + const std::string& tensorName) { + if (absl::EndsWith(tensorName, "deq_scale") && + (torch_dtype_.compare("bfloat16") == 0)) { + return; + } + + if (tensor.dtype() != torch::kInt8 && tensor.dtype() != torch::kInt32 && + tensor.dtype() != torch::kInt64) { + torch::Dtype dtype = string2dtype(torch_dtype_); + tensor = tensor.to(dtype); + } +} + +torch::Dtype BaseLoader::string2dtype(const std::string& dtype_str) { + if (dtype_str.compare("float16") == 0) { + return torch::kFloat16; + } else if (dtype_str.compare("bfloat16") == 0) { + return torch::kBFloat16; + } else if (dtype_str.compare("float32") == 0) { + return torch::kFloat32; + } else if (dtype_str.compare("float64") == 0) { + return torch::kFloat64; + } else if 
(dtype_str.compare("int8") == 0) {
+    return torch::kInt8;
+  } else if (dtype_str.compare("int16") == 0) {
+    return torch::kInt16;
+  } else if (dtype_str.compare("int32") == 0) {
+    return torch::kInt32;
+  } else if (dtype_str.compare("int64") == 0) {
+    return torch::kInt64;
+  } else if (dtype_str.compare("uint8") == 0) {
+    return torch::kUInt8;
+  } else if (dtype_str.compare("bool") == 0) {
+    return torch::kBool;
+  }
+
+  LOG(FATAL) << "Unsupported dtype string: " << dtype_str;
+}
+
+}  // namespace layer
+}  // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/layers/npu/loader/base_loader.h b/xllm/core/layers/npu/loader/base_loader.h
new file mode 100644
index 000000000..3429c571c
--- /dev/null
+++ b/xllm/core/layers/npu/loader/base_loader.h
@@ -0,0 +1,102 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include
+#include
+
+#include "framework/eplb/expert_buffer_manager.h"
+#include "framework/kv_cache/kv_cache.h"
+#include "framework/model/model_input_params.h"
+#include "framework/model_context.h"
+#include "framework/state_dict/state_dict.h"
+#include "xllm_kernels/pytorch/atb_torch/core/include/base_operation.h"
+#include "xllm_kernels/pytorch/atb_torch/core/include/graph_operation.h"
+
+namespace xllm {
+namespace layer {
+
+class BaseLoader {
+ public:
+  BaseLoader(uint64_t weight_count, const ModelContext& context);
+  virtual ~BaseLoader() = default;
+
+  virtual void load_state_dict(const StateDict& state_dict) {};
+  virtual void verify_loaded_weights() const {};
+  virtual void verify_loaded_weights(const std::string& prefix) const {};
+  virtual void merge_loaded_weights() {};
+  virtual void resize_experts_weights(int num_of_device_experts) {};
+  torch::Dtype string2dtype(const std::string& dtype_str);
+
+  void correct_tensor_dtype(torch::Tensor& tensor,
+                            const std::string& tensorName);
+
+  std::vector<at::Tensor>& get_at_weight_tensors() {
+    return at_weight_tensors_;
+  }
+
+  std::unordered_map<std::string, std::vector<at::Tensor>>&
+  get_experts_weight_tensors() {
+    return experts_weights_;
+  }
+
+  std::unique_ptr<ExpertBufferManager>& get_expert_shared_buffer() {
+    return shared_buffer_;
+  }
+
+  std::vector<int32_t>& get_device_expert_list() { return device_expert_list_; }
+
+  atb_torch::TorchTensorMap& get_weights_map() { return weights_map_; }
+
+ protected:
+  uint64_t weight_count_;
+  xllm::ParallelArgs parallel_args_;
+  std::string quantize_type_;
+  std::string torch_dtype_;
+  torch::ScalarType dtype_;
+  torch::TensorOptions options_;
+  std::vector<at::Tensor> at_weight_tensors_;
+  std::unique_ptr<ExpertBufferManager> shared_buffer_ = nullptr;
+  std::unordered_map<std::string, at::Tensor> shared_experts_weights_;
+  std::unordered_map<std::string, std::vector<at::Tensor>> experts_weights_;
+  std::vector<int32_t> device_expert_list_;
+  atb_torch::TorchTensorMap weights_map_;
+
+  at::Device device_;
+  int32_t dp_size_;
+  int32_t dp_local_tp_size_;
+  int32_t dp_rank_;
+  int32_t dp_local_tp_rank_;
+
+  void set_weight(const StateDict& state_dict,
+                  const std::string& tensor_name,
+                  int
weight_position);
+
+  void set_weight(const StateDict& state_dict,
+                  const std::string& tensor_name,
+                  int weight_position,
+                  int dim);
+
+  void set_weight(const StateDict& state_dict,
+                  const std::string& tensor_name,
+                  int weight_position,
+                  int dim,
+                  int rank,
+                  int world_size);
+};
+
+}  // namespace layer
+}  // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp b/xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp
new file mode 100644
index 000000000..35836e916
--- /dev/null
+++ b/xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp
@@ -0,0 +1,47 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "column_parallel_linear_loader.h"
+
+namespace xllm {
+namespace layer {
+
+ColumParallelLinearLoader::ColumParallelLinearLoader(
+    uint64_t weight_count,
+    const ModelContext& context)
+    : BaseLoader(weight_count, context) {
+  auto options = context.get_tensor_options();
+  dtype_ = torch::typeMetaToScalarType(options.dtype());
+  at_weight_tensors_[0] = torch::zeros({1}).to(options);
+}
+
+void ColumParallelLinearLoader::load_state_dict(const StateDict& state_dict) {
+  if (dp_size_ > 1) {
+    set_weight(
+        state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_);
+  } else {
+    set_weight(state_dict, "weight", 0, 0);
+  }
+  at_weight_tensors_[0] = at_weight_tensors_[0].to(dtype_);
+}
+
+void ColumParallelLinearLoader::verify_loaded_weights(
+    const std::string& weight_str) const {
+  CHECK(at_weight_tensors_[0].sizes() != std::vector<int64_t>({1}))
+      << "weight is not loaded for " << weight_str;
+}
+
+}  // namespace layer
+}  // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/layers/npu/loader/column_parallel_linear_loader.h b/xllm/core/layers/npu/loader/column_parallel_linear_loader.h
new file mode 100644
index 000000000..5a52fa775
--- /dev/null
+++ b/xllm/core/layers/npu/loader/column_parallel_linear_loader.h
@@ -0,0 +1,28 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "base_loader.h" + +namespace xllm { +namespace layer { +class ColumParallelLinearLoader : public BaseLoader { + public: + ColumParallelLinearLoader(uint64_t weight_count, const ModelContext& context); + + void load_state_dict(const StateDict& state_dict) override; + void verify_loaded_weights(const std::string& weight_str) const override; +}; +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/deepseek_v2_decoder_loader.cpp b/xllm/core/layers/npu/loader/deepseek_v2_decoder_loader.cpp new file mode 100644 index 000000000..f62062b6f --- /dev/null +++ b/xllm/core/layers/npu/loader/deepseek_v2_decoder_loader.cpp @@ -0,0 +1,994 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "deepseek_v2_decoder_loader.h" + +#include + +namespace xllm { +namespace layer { + +enum DecoderLayerTensorId : int { + IN_INPUT_NORM_WEIGHT = 0, + IN_INPUT_NORM_BIAS = 1, + IN_INPUT_NORM_NEW_WEIGHT = 2, + IN_INPUT_NORM_NEW_BIAS = 3, + + IN_Q_PROJ_A_WEIGHT = 4, + IN_Q_PROJ_A_BIAS = 5, + IN_Q_PROJ_A_DESCALE = 6, + IN_Q_PROJ_A_OFFSET = 7, + IN_Q_PROJ_A_SCALE = 8, + IN_Q_PROJ_A_COMPRESS_IDX = 9, + IN_Q_PROJ_A_LAYERNORM_WEIGHT = 10, + IN_Q_PROJ_A_LAYERNORM_BIAS = 11, + + IN_Q_PROJ_B_WEIGHT = 12, + IN_Q_PROJ_B_BIAS = 13, + IN_Q_PROJ_B_DESCALE = 14, + IN_Q_PROJ_B_OFFSET = 15, + IN_Q_PROJ_B_SCALE = 16, + IN_Q_PROJ_B_COMPRESS_IDX = 17, + + IN_KV_PROJ_WITH_MQA_WEIGHT = 18, + IN_KV_PROJ_WITH_MQA_BIAS = 19, + IN_KV_PROJ_WITH_MQA_DESCALE = 20, + IN_KV_PROJ_WITH_MQA_OFFSET = 21, + IN_KV_PROJ_WITH_MQA_SCALE = 22, + IN_KV_PROJ_WITH_MQA_COMPRESS_IDX = 23, + + IN_KV_PROJ_A_LAYERNORM_WEIGHT = 24, + IN_KV_PROJ_A_LAYERNORM_BIAS = 25, + + IN_K_PROJ_B_FOR_Q_WEIGHT = 26, + IN_K_PROJ_B_FOR_Q_BIAS = 27, + IN_K_PROJ_B_FOR_Q_DESCALE = 28, + IN_K_PROJ_B_FOR_Q_OFFSET = 29, + IN_K_PROJ_B_FOR_Q_SCALE = 30, + IN_K_PROJ_B_FOR_Q_COMPRESS_IDX = 31, + + IN_V_PROJ_B_FOR_O_WEIGHT = 32, + IN_V_PROJ_B_FOR_O_BIAS = 33, + IN_V_PROJ_B_FOR_O_DESCALE = 34, + IN_V_PROJ_B_FOR_O_OFFSET = 35, + IN_V_PROJ_B_FOR_O_SCALE = 36, + IN_V_PROJ_B_FOR_O_COMPRESS_IDX = 37, + + IN_ATTENTION_OUT_WEIGHT = 38, + IN_ATTENTION_OUT_BIAS = 39, + IN_ATTENTION_OUT_DESCALE = 40, + IN_ATTENTION_OUT_OFFSET = 41, + IN_ATTENTION_OUT_SCALE = 42, + IN_ATTENTION_OUT_COMPRESS_IDX = 43, + + IN_SELFATTENTION_OUT_NORM_WEIGHT = 44, + IN_SELFATTENTION_OUT_NORM_BIAS = 45, + IN_SELFATTENTION_OUT_NEW_NORM_WEIGHT = 46, + IN_SELFATTENTION_OUT_NEW_NORM_BIAS = 47, + + IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT = 48, + IN_MLP_GATEUP_BIAS_SHARED_EXPERT = 49, + IN_MLP_GATEUP_DESCALE_SHARED_EXPERT = 50, + IN_MLP_GATEUP_OFFSET_SHARED_EXPERT = 51, + IN_MLP_GATEUP_SCALE_SHARED_EXPERT = 52, + IN_MLP_GATEUP_COMPRESS_IDX_SHARED_EXPERT = 53, + + IN_MLP_DOWN_WEIGHT_SHARED_EXPERT = 54, + IN_MLP_DOWN_BIAS_SHARED_EXPERT = 55, + 
IN_MLP_DOWN_DESCALE_SHARED_EXPERT = 56, + IN_MLP_DOWN_OFFSET_SHARED_EXPERT = 57, + IN_MLP_DOWN_SCALE_SHARED_EXPERT = 58, + IN_MLP_DOWN_COMPRESS_IDX_SHARED_EXPERT = 59, + + IN_SHARED_EXPERT_GATE_WEIGHT = 60, + IN_SHARED_EXPERT_GATE_BIAS = 61, + IN_SHARED_EXPERT_GATE_DESCALE = 62, + IN_SHARED_EXPERT_GATE_OFFSET = 63, + IN_SHARED_EXPERT_GATE_SCALE = 64, + IN_SHARED_EXPERT_GATE_COMPRESS_IDX = 65, + + IN_BLOCK_SPARSE_MOE_GATE_WEIGHT = 66, + IN_BLOCK_SPARSE_MOE_GATE_BIAS = 67, + IN_BLOCK_SPARSE_MOE_GATE_DESCALE = 68, + IN_BLOCK_SPARSE_MOE_GATE_OFFSET = 69, + IN_BLOCK_SPARSE_MOE_GATE_SCALE = 70, + IN_BLOCK_SPARSE_MOE_GATE_COMPRESS_IDX = 71, + + IN_MLP_GATEUP_WEIGHT_EXPERT = 72, + IN_MLP_GATEUP_BIAS_EXPERT = 73, + IN_MLP_GATEUP_DESCALE_EXPERT = 74, + IN_MLP_GATEUP_OFFSET_EXPERT = 75, + IN_MLP_GATEUP_SCALE_EXPERT = 76, + IN_MLP_GATEUP_COMPRESS_IDX_EXPERT = 77, + + IN_MLP_DOWN_WEIGHT_EXPERT = 78, + IN_MLP_DOWN_BIAS_EXPERT = 79, + IN_MLP_DOWN_DESCALE_EXPERT = 80, + IN_MLP_DOWN_OFFSET_EXPERT = 81, + IN_MLP_DOWN_SCALE_EXPERT = 82, + IN_MLP_DOWN_COMPRESS_IDX_EXPERT = 83, +}; + +static std::vector> WEIGHT_MAPPING = {}; + +static const std::unordered_map WEIGHT_MAPPING_W8A8 = { + {"input_layernorm.weight", IN_INPUT_NORM_WEIGHT}, + {"input_layernorm.bias", IN_INPUT_NORM_BIAS}, + + {"self_attn.q_a_proj.weight", IN_Q_PROJ_A_WEIGHT}, + {"self_attn.q_a_proj.quant_bias", IN_Q_PROJ_A_BIAS}, + {"self_attn.q_a_proj.deq_scale", IN_Q_PROJ_A_DESCALE}, + {"self_attn.q_a_proj.input_offset", IN_Q_PROJ_A_OFFSET}, + {"self_attn.q_a_proj.input_scale", IN_Q_PROJ_A_SCALE}, + {"self_attn.q_a_layernorm.weight", IN_Q_PROJ_A_LAYERNORM_WEIGHT}, + {"self_attn.q_a_layernorm.bias", IN_Q_PROJ_A_LAYERNORM_BIAS}, + + {"self_attn.q_proj.weight", IN_Q_PROJ_B_WEIGHT}, + {"self_attn.q_b_proj.weight", IN_Q_PROJ_B_WEIGHT}, + {"self_attn.q_b_proj.quant_bias", IN_Q_PROJ_B_BIAS}, + {"self_attn.q_b_proj.input_scale", IN_Q_PROJ_B_SCALE}, + {"self_attn.q_b_proj.deq_scale", IN_Q_PROJ_B_DESCALE}, + {"self_attn.q_b_proj.input_offset", IN_Q_PROJ_B_OFFSET}, + + {"self_attn.kv_a_proj_with_mqa.weight", IN_KV_PROJ_WITH_MQA_WEIGHT}, + {"self_attn.kv_a_proj_with_mqa.quant_bias", IN_KV_PROJ_WITH_MQA_BIAS}, + {"self_attn.kv_a_proj_with_mqa.deq_scale", IN_KV_PROJ_WITH_MQA_DESCALE}, + {"self_attn.kv_a_proj_with_mqa.input_offset", IN_KV_PROJ_WITH_MQA_OFFSET}, + {"self_attn.kv_a_proj_with_mqa.input_scale", IN_KV_PROJ_WITH_MQA_SCALE}, + + {"self_attn.kv_a_layernorm.weight", IN_KV_PROJ_A_LAYERNORM_WEIGHT}, + {"self_attn.kv_a_layernorm.bias", IN_KV_PROJ_A_LAYERNORM_BIAS}, + + {"self_attn.kv_b_proj.weight", IN_K_PROJ_B_FOR_Q_WEIGHT}, // merge + // {"self_attn.kv_b_proj.weight", IN_V_PROJ_B_FOR_O_WEIGHT}, // merge + + {"self_attn.o_proj.weight", IN_ATTENTION_OUT_WEIGHT}, + {"self_attn.o_proj.quant_bias", IN_ATTENTION_OUT_BIAS}, + {"self_attn.o_proj.deq_scale", IN_ATTENTION_OUT_DESCALE}, + {"self_attn.o_proj.input_offset", IN_ATTENTION_OUT_OFFSET}, + {"self_attn.o_proj.input_scale", IN_ATTENTION_OUT_SCALE}, + + {"post_attention_layernorm.weight", IN_SELFATTENTION_OUT_NORM_WEIGHT}, + {"post_attention_layernorm.bias", IN_SELFATTENTION_OUT_NORM_BIAS}, + + {"mlp.gate_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + {"mlp.gate_proj.weight_offset", IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, + {"mlp.gate_proj.weight_scale", IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, + + {"mlp.up_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + {"mlp.up_proj.weight_offset", IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, + {"mlp.up_proj.weight_scale", IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, + + 
{"mlp.down_proj.weight", IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, + {"mlp.down_proj.weight_offset", IN_MLP_DOWN_OFFSET_SHARED_EXPERT}, + {"mlp.down_proj.weight_scale", IN_MLP_DOWN_SCALE_SHARED_EXPERT}, + + {"mlp.shared_experts.gate_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + {"mlp.shared_experts.gate_proj.weight_offset", + IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, + {"mlp.shared_experts.gate_proj.weight_scale", + IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, + + {"mlp.shared_experts.up_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + {"mlp.shared_experts.up_proj.weight_offset", + IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, + {"mlp.shared_experts.up_proj.weight_scale", + IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, + + {"mlp.shared_experts.down_proj.weight", IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, + {"mlp.shared_experts.down_proj.weight_offset", + IN_MLP_DOWN_OFFSET_SHARED_EXPERT}, + {"mlp.shared_experts.down_proj.weight_scale", + IN_MLP_DOWN_SCALE_SHARED_EXPERT}, + + {"mlp.gate.weight", IN_BLOCK_SPARSE_MOE_GATE_WEIGHT}, + {"mlp.gate.e_score_correction_bias", IN_BLOCK_SPARSE_MOE_GATE_BIAS}, + + {"gate_proj.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, + {"gate_proj.weight_offset", IN_MLP_GATEUP_OFFSET_EXPERT}, + {"gate_proj.weight_scale", IN_MLP_GATEUP_SCALE_EXPERT}, + {"up_proj.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, + {"up_proj.weight_offset", IN_MLP_GATEUP_OFFSET_EXPERT}, + {"up_proj.weight_scale", IN_MLP_GATEUP_SCALE_EXPERT}, + + {"down_proj.weight", IN_MLP_DOWN_WEIGHT_EXPERT}, + {"down_proj.weight_offset", IN_MLP_DOWN_OFFSET_EXPERT}, + {"down_proj.weight_scale", IN_MLP_DOWN_SCALE_EXPERT}, +}; + +static const std::map WEIGHT_SHARD = {}; + +static const std::map WEIGHT_SHARD_W8A8 = { + {IN_Q_PROJ_B_WEIGHT, 0}, + {IN_Q_PROJ_B_BIAS, 0}, + {IN_Q_PROJ_B_DESCALE, 0}, + {IN_K_PROJ_B_FOR_Q_WEIGHT, 0}, + {IN_V_PROJ_B_FOR_O_WEIGHT, 0}, + {IN_ATTENTION_OUT_WEIGHT, 1}, + {IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, 0}, + {IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, 0}, + {IN_MLP_GATEUP_SCALE_SHARED_EXPERT, 0}, + {IN_MLP_DOWN_WEIGHT_SHARED_EXPERT, 1}, + {IN_MLP_GATEUP_WEIGHT_EXPERT, 0}, + {IN_MLP_GATEUP_OFFSET_EXPERT, 0}, + {IN_MLP_GATEUP_SCALE_EXPERT, 0}, + {IN_MLP_DOWN_WEIGHT_EXPERT, 1}, +}; + +static std::vector SQUEEZE_WEIGHT_VEC = { + IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, + IN_MLP_GATEUP_SCALE_SHARED_EXPERT, + IN_MLP_DOWN_OFFSET_SHARED_EXPERT, + IN_MLP_DOWN_SCALE_SHARED_EXPERT}; + +static std::vector LINEAR_FOR_ROPE = { + "self_attn.q_b_proj.weight", + "self_attn.q_b_proj.quant_bias", + "self_attn.q_b_proj.deq_scale", + "self_attn.kv_a_proj_with_mqa.weight", + "self_attn.kv_a_proj_with_mqa.quant_bias", + "self_attn.kv_a_proj_with_mqa.deq_scale", +}; + +DeekseekV2DecoderLoader::DeekseekV2DecoderLoader( + uint64_t weight_count, + const ModelContext& context, + int32_t layer_id, + int32_t prefill_firstKDenseReplace, + int32_t prefill_numOfDeviceExperts, + int32_t prefill_qkRopeHeadDim, + int32_t prefill_numAttentionHeadsPerRank, + int32_t decode_worldSize, + int32_t qk_nope_head_dim, + int32_t kv_lora_rank, + int32_t num_key_value_heads, + int32_t v_head_dim, + bool prefill_isBF16, + bool decode_isBF16) + : BaseLoader(weight_count, context), + layer_id_(layer_id), + prefill_firstKDenseReplace_(prefill_firstKDenseReplace), + prefill_numOfDeviceExperts_(prefill_numOfDeviceExperts), + prefill_qkRopeHeadDim_(prefill_qkRopeHeadDim), + prefill_numAttentionHeadsPerRank_(prefill_numAttentionHeadsPerRank), + decode_worldSize_(decode_worldSize), + qk_nope_head_dim_(qk_nope_head_dim), + kv_lora_rank_(kv_lora_rank), + num_key_value_heads_(num_key_value_heads), + 
v_head_dim_(v_head_dim), + prefill_isBF16_(prefill_isBF16), + decode_isBF16_(decode_isBF16) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + rank_ = parallel_args_.rank(); + first_k_dense_replace_ = model_args.first_k_dense_replace(); + n_layers_ = model_args.n_layers(); + num_experts_ = model_args.n_routed_experts(); + localWorldSize_ = parallel_args_.mapping().localWorldSize(); + ep_size_ = parallel_args_.ep_size(); + ep_local_tp_size_ = parallel_args_.world_size() / ep_size_; + CHECK_EQ(parallel_args_.world_size(), ep_size_ * ep_local_tp_size_); + ep_local_tp_rank_ = parallel_args_.rank() % ep_local_tp_size_; + num_experts_per_partition_ = model_args.n_routed_experts() / ep_size_; + redundant_experts_num_ = FLAGS_redundant_experts_num; + if (FLAGS_enable_eplb) { + num_experts_per_partition_ += redundant_experts_num_; + } + ep_rank_ = parallel_args_.rank() / ep_local_tp_size_; + start_expert_id_ = ep_rank_ * num_experts_per_partition_; + end_expert_id_ = start_expert_id_ + num_experts_per_partition_ - 1; + initialize_tensors(options); + initialize_weight_tensors(options); +} + +void DeekseekV2DecoderLoader::initialize_tensors( + const torch::TensorOptions& options) { + tensor_placeholder_ = torch::zeros({1}).to(options); + reserve_experts_weights(prefill_numOfDeviceExperts_); + initialize_device_expert_list(decode_worldSize_, num_experts_per_partition_); +} + +void DeekseekV2DecoderLoader::load_state_dict(const StateDict& state_dict) { + for (const auto& [name, tensor] : state_dict) { + bool is_sharded = false; + int index = 0; + + if (absl::EndsWith(name, "self_attn.kv_b_proj.weight")) { + index = WEIGHT_MAPPING_W8A8.at(name); + set_kv_weight(state_dict, name, index, WEIGHT_SHARD_W8A8.at(index)); + continue; + } + + if (absl::StartsWith(name, "mlp.experts")) { + process_expert_weights(state_dict, name, tensor); + continue; + } + + if (absl::StartsWith(name, "mlp.shared_experts")) { + process_shared_expert_weights(state_dict, name, tensor); + continue; + } + + if (absl::StartsWith(name, "mlp") && !absl::StrContains(name, "gate.")) { + process_mlp_common_weights(state_dict, name, tensor); + continue; + } + + process_general_weights(state_dict, name, tensor); + } +} + +void DeekseekV2DecoderLoader::verify_loaded_weights( + const std::string& prefix) const { + for (const auto& [index, name] : WEIGHT_MAPPING) { + CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) + << "weight is not loaded for " << prefix + name; + } +} + +int DeekseekV2DecoderLoader::extract_expert_index(const std::string& name) { + std::string prefix = "experts."; + size_t pos = name.find(prefix); + if (pos != std::string::npos) { + pos += prefix.length(); + size_t end_pos = pos; + while (end_pos < name.length() && std::isdigit(name[end_pos])) { + ++end_pos; + } + if (end_pos > pos) { + return std::stoi(name.substr(pos, end_pos - pos)); + } + } + return -1; +} + +void DeekseekV2DecoderLoader::process_expert_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + // Step 1: Early checks and basic info extraction + int expert_index = extract_expert_index(name); + const std::string suffix = extract_endswith(name); + const int index = get_mapped_index(suffix, WEIGHT_MAPPING_W8A8); + if (index == -1) { + return; + } + + const bool is_sharded = WEIGHT_SHARD_W8A8.count(index); + const bool needs_eplb = FLAGS_enable_eplb && (rank_ % localWorldSize_ == + expert_index % localWorldSize_); + + // Step 2: Check if expert is in 
partition + const int start_idx = ep_rank_ * num_experts_per_partition_; + const int end_idx = (ep_rank_ + 1) * num_experts_per_partition_; + const int safe_end = + std::min(end_idx, static_cast(device_expert_list_.size())); + + auto it = std::find(device_expert_list_.begin() + start_idx, + device_expert_list_.begin() + safe_end, + expert_index); + const bool in_partition = it != device_expert_list_.begin() + safe_end; + + // Early return if neither EPLB nor partition needs this expert + if (!needs_eplb && !in_partition) { + return; + } + + // Step 3: Process tensor + torch::Tensor processed_tensor; + { + std::lock_guard lock(experts_mutex_); + processed_tensor = is_sharded + ? get_sharded_tensor(state_dict, + name, + WEIGHT_SHARD_W8A8.at(index), + ep_local_tp_rank_, + ep_local_tp_size_) + : tensor; + + if (!decode_isBF16_) { + if (absl::EndsWith(name, "_offset")) { + processed_tensor = processed_tensor.to(torch::kFloat16); + } else if (absl::EndsWith(name, "_scale")) { + processed_tensor = processed_tensor.to(torch::kFloat32); + } + } + } + + // Step 4: Handle EPLB case + if (needs_eplb) { + std::lock_guard lock(experts_mutex_); + std::string shm_key = get_expert_shm_key(layer_id_, expert_index, suffix); + shared_buffer_->add_tensor(expert_index, + layer_id_ - first_k_dense_replace_, + shm_key, + processed_tensor.contiguous()); + } + + // Step 5: Handle partition case + if (in_partition) { + std::vector matches_pos; + for (auto iter = it; iter != device_expert_list_.begin() + safe_end; + ++iter) { + if (*iter == expert_index) { + matches_pos.emplace_back( + std::distance(device_expert_list_.begin(), iter) - start_idx); + } + } + + if (!matches_pos.empty()) { + std::lock_guard lock(experts_mutex_); + for (auto pos : matches_pos) { + experts_weights_[suffix][pos] = processed_tensor.clone(); + } + } + } +} + +void DeekseekV2DecoderLoader::initialize_weight_tensors( + const torch::TensorOptions& options) { + for (int i = 0; i < weight_count_; ++i) { + at_weight_tensors_[i] = torch::zeros({1}).to(options); + } + + if (FLAGS_enable_eplb) { + const int64_t size = + 50LL * 1024LL * 1024LL * int64_t(n_layers_ - first_k_dense_replace_); + shared_buffer_ = std::make_unique( + num_experts_, n_layers_ - first_k_dense_replace_, size); + } +} + +void DeekseekV2DecoderLoader::convert_offsets_to_int8() { + auto convert_to_int8 = [this](int index) { + at_weight_tensors_[index] = + at_weight_tensors_[index].to(torch::kInt8).to(device_); + }; + convert_to_int8(IN_Q_PROJ_A_OFFSET); + convert_to_int8(IN_Q_PROJ_B_OFFSET); + convert_to_int8(IN_KV_PROJ_WITH_MQA_OFFSET); + convert_to_int8(IN_ATTENTION_OUT_OFFSET); +} + +void DeekseekV2DecoderLoader::handle_device_specific_bias() { + if (dp_local_tp_rank_ != 0) { + torch::Tensor original_tensor = at_weight_tensors_[IN_ATTENTION_OUT_BIAS]; + at_weight_tensors_[IN_ATTENTION_OUT_BIAS] = + torch::zeros(original_tensor.sizes(), + torch::TensorOptions() + .dtype(original_tensor.dtype()) + .device(original_tensor.device())); + } +} + +std::string DeekseekV2DecoderLoader::extract_endswith( + const std::string& input) { + std::vector parts; + std::stringstream ss(input); + std::string part; + while (std::getline(ss, part, '.')) { + parts.emplace_back(part); + } + if (parts.size() < 2) { + return ""; + } + std::string result = parts[parts.size() - 2] + "." 
+ parts[parts.size() - 1]; + return result; +} + +torch::Tensor DeekseekV2DecoderLoader::get_sharded_tensor( + const StateDict& state_dict, + const std::string& name, + int dim) { + if (parallel_args_.world_size() > 1) { + return state_dict.get_sharded_tensor( + name, dim, parallel_args_.rank(), parallel_args_.world_size()); + } else { + return state_dict.get_tensor(name); + } +} + +torch::Tensor DeekseekV2DecoderLoader::get_sharded_tensor( + const StateDict& state_dict, + const std::string& name, + int dim, + int loacal_tp_rank, + int local_tp_size) { + if (local_tp_size > 1) { + return state_dict.get_sharded_tensor( + name, dim, loacal_tp_rank, local_tp_size); + } else { + return state_dict.get_tensor(name); + } +} + +int DeekseekV2DecoderLoader::get_mapped_index( + const std::string& name, + const std::unordered_map& mapping) { + const auto it = mapping.find(name); + if (it == mapping.end()) { + LOG(WARNING) << "Parameter '" << name + << "' not found in mapping and will not be used."; + return -1; + } + return it->second; +} + +void DeekseekV2DecoderLoader::squeeze_experts_weights() { + for (const auto& index : SQUEEZE_WEIGHT_VEC) { + if (at_weight_tensors_[index].dim() > 1) { + at_weight_tensors_[index] = at_weight_tensors_[index].squeeze(); + } + } +} + +void DeekseekV2DecoderLoader::process_general_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + const int index = get_mapped_index(name, WEIGHT_MAPPING_W8A8); + if (index == -1) { + return; + } + const bool is_sharded = WEIGHT_SHARD_W8A8.count(index); + torch::Tensor tmp_tensor; + + tmp_tensor = is_sharded ? get_sharded_tensor(state_dict, + name, + WEIGHT_SHARD_W8A8.at(index), + dp_local_tp_rank_, + dp_local_tp_size_) + .to(device_) + : tensor.to(device_); + + correct_tensor_dtype(tmp_tensor, name); + at_weight_tensors_[index] = tmp_tensor; +} + +void DeekseekV2DecoderLoader::process_mlp_common_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + const int index = get_mapped_index(name, WEIGHT_MAPPING_W8A8); + if (index == -1) { + return; + } + const bool is_sharded = WEIGHT_SHARD_W8A8.count(index); + std::lock_guard lock(shared_experts_mutex_); + + torch::Tensor tmp_tensor = + is_sharded ? 
get_sharded_tensor(state_dict, + name, + WEIGHT_SHARD_W8A8.at(index), + dp_local_tp_rank_, + dp_local_tp_size_) + .to(device_) + : tensor.to(device_); + if (absl::StrContains(name, "down_proj")) { + at_weight_tensors_[index] = tmp_tensor; + } else { + shared_experts_weights_[name] = tmp_tensor; + } +} + +void DeekseekV2DecoderLoader::merge_experts_weights() { + torch::Tensor mlp_gateup_weight = + merge_experts_weights(experts_weights_["gate_proj.weight"], + experts_weights_["up_proj.weight"], + device_, + /*transpose=*/true); + at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT] = + at_npu::native::npu_format_cast(mlp_gateup_weight, 29); + // at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT] = + // at_npu::native::npu_format_cast(mlp_gateup_weight, 2).contiguous(); + if (quantize_type_ == "w8a8_dynamic") { + at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT] = + merge_experts_weights(experts_weights_["gate_proj.weight_offset"], + experts_weights_["up_proj.weight_offset"], + device_); + at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT] = + merge_experts_weights(experts_weights_["gate_proj.weight_scale"], + experts_weights_["up_proj.weight_scale"], + device_); + } + +#if defined(USE_A3) + torch::Tensor mlp_down_weight = + merge_experts_weights(experts_weights_["down_proj.weight"], + device_, + /*transpose=*/false); + // at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = + // at_npu::native::npu_format_cast(mlp_down_weight, 29); + at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = + at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); +#else + // TODO: xllm ops's GMM need to support MTP. + if (decode_isBF16_ && false) { + torch::Tensor mlp_down_weight = + merge_experts_weights(experts_weights_["down_proj.weight"], + device_, + /*transpose=*/true); + at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = + at_npu::native::npu_format_cast(mlp_down_weight, 29); + } else { + torch::Tensor mlp_down_weight = + merge_experts_weights(experts_weights_["down_proj.weight"], + device_, + /*transpose=*/false); + at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = + at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); + } +#endif + if (quantize_type_ == "w8a8_dynamic") { + at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = merge_experts_weights( + experts_weights_["down_proj.weight_offset"], device_); + at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = merge_experts_weights( + experts_weights_["down_proj.weight_scale"], device_); + } +} + +torch::Tensor DeekseekV2DecoderLoader::merge_experts_weights( + std::vector& experts, + at::Device device, + bool transpose) { + torch::Tensor merged_tensor = torch::stack(experts, 0).to(device); + if (transpose) { + merged_tensor = merged_tensor.transpose(1, 2); + } + merged_tensor = merged_tensor.contiguous(); + experts.clear(); + return merged_tensor; +} + +torch::Tensor DeekseekV2DecoderLoader::merge_experts_weights( + std::vector& experts_gate, + std::vector& experts_up, + at::Device device, + bool transpose) { + for (size_t i = 0; i < experts_up.size(); ++i) { + experts_gate[i] = torch::cat({experts_gate[i], experts_up[i]}, 0); + } + + torch::Tensor merged_tensor = torch::stack(experts_gate, 0).to(device); + + if (transpose) { + merged_tensor = merged_tensor.transpose(1, 2); + } + + merged_tensor = merged_tensor.contiguous(); + experts_gate.clear(); + experts_up.clear(); + return merged_tensor; +} + +void DeekseekV2DecoderLoader::process_shared_expert_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + torch::Tensor 
tmp_tensor; + std::lock_guard lock(shared_experts_mutex_); + const int index = get_mapped_index(name, WEIGHT_MAPPING_W8A8); + if (index == -1) { + return; + } + if (FLAGS_expert_parallel_degree == 2) { + tmp_tensor = tensor.to(device_); + } else { + const bool is_sharded = WEIGHT_SHARD_W8A8.count(index); + tmp_tensor = is_sharded ? get_sharded_tensor( + state_dict, name, WEIGHT_SHARD_W8A8.at(index)) + .to(device_) + : tensor.to(device_); + } + if (absl::StrContains(name, "down_proj")) { + at_weight_tensors_[index] = tmp_tensor; + } else { + shared_experts_weights_[name] = tmp_tensor; + } +} + +void DeekseekV2DecoderLoader::set_kv_weight(const StateDict& state_dict, + const std::string& tensor_name, + int weight_position, + int dim) { + torch::Tensor mutable_tensor; + if (parallel_args_.world_size() <= 1) { + mutable_tensor = state_dict.get_tensor(tensor_name).to(device_); + correct_tensor_dtype(mutable_tensor, tensor_name); + } else { + mutable_tensor = + get_sharded_tensor( + state_dict, tensor_name, dim, dp_local_tp_rank_, dp_local_tp_size_) + .to(device_); + // mutable_tensor = get_sharded_tensor(state_dict, tensor_name, dim); + correct_tensor_dtype(mutable_tensor, tensor_name); + } + + torch::Tensor kv_b_proj_weight = + mutable_tensor.reshape({num_key_value_heads_ / dp_local_tp_size_, + qk_nope_head_dim_ + v_head_dim_, + kv_lora_rank_}); + torch::Tensor k_b_proj_preprocessed = + kv_b_proj_weight.slice(1, 0, qk_nope_head_dim_).contiguous(); + torch::Tensor v_b_proj_preprocessed = + kv_b_proj_weight + .slice(1, qk_nope_head_dim_, qk_nope_head_dim_ + v_head_dim_) + .transpose(1, 2) + .contiguous(); + at_weight_tensors_[weight_position] = k_b_proj_preprocessed.to(device_); + at_weight_tensors_[weight_position + 6] = v_b_proj_preprocessed.to(device_); +} + +void DeekseekV2DecoderLoader::preprocess_linear_for_rope() { + for (const auto& name : LINEAR_FOR_ROPE) { + if (quantize_type_ == "") { + if (!absl::EndsWith(name, "weight")) { + continue; + } + } + int index = WEIGHT_MAPPING_W8A8.at(name); + at_weight_tensors_[index] = + view_tensor(at_weight_tensors_[index], name, true); + at_weight_tensors_[index] = trans_rope_weight(at_weight_tensors_[index]); + at_weight_tensors_[index] = + (!absl::EndsWith(name, "weight")) + ? 
view_tensor(at_weight_tensors_[index], name, false).flatten() + : view_tensor(at_weight_tensors_[index], name, false); + } +} + +torch::Tensor DeekseekV2DecoderLoader::view_tensor(torch::Tensor weight, + const std::string& name, + bool pre_view) { + if (absl::StrContains(name, "q_b_proj")) { + if (pre_view) { + return weight + .view({prefill_numAttentionHeadsPerRank_, + qk_nope_head_dim_ + prefill_qkRopeHeadDim_, + -1}) + .contiguous(); + } else { + return weight + .view({prefill_numAttentionHeadsPerRank_ * + (qk_nope_head_dim_ + prefill_qkRopeHeadDim_), + -1}) + .contiguous(); + } + } else if (absl::StrContains(name, "kv_a_proj_with_mqa")) { + return weight.view({kv_lora_rank_ + prefill_qkRopeHeadDim_, -1}) + .contiguous(); + } + return weight; +} + +torch::Tensor DeekseekV2DecoderLoader::trans_rope_weight(torch::Tensor weight) { + int64_t d = weight.size(-2); + int64_t rope_dim = prefill_qkRopeHeadDim_; + torch::Tensor weight_1 = + weight.slice(-2, d - rope_dim, torch::indexing::None, 2).contiguous(); + + torch::Tensor weight_2 = + weight.slice(-2, d - rope_dim + 1, torch::indexing::None, 2).contiguous(); + + torch::Tensor combined = torch::cat({weight_1, weight_2}, -2); + + weight.slice(-2, d - rope_dim, d).copy_(combined); + + return weight.contiguous(); +} + +void DeekseekV2DecoderLoader::initialize_device_expert_list( + int num_device, + int num_device_expert) { + int32_t num_device_route_expert = num_device_expert; + if (FLAGS_enable_eplb) { + num_device_route_expert = num_device_expert - redundant_experts_num_; + } + for (int i = 0; i < num_device * num_device_route_expert; ++i) { + device_expert_list_.emplace_back(i); + if (FLAGS_enable_eplb && (i + 1) % num_device_route_expert == 0) { + for (int redundant_expert = 0; redundant_expert < redundant_experts_num_; + ++redundant_expert) + device_expert_list_.emplace_back(i); + } + } +} + +torch::Tensor DeekseekV2DecoderLoader::convert_fp16_to_int64( + const torch::Tensor& fp16_tensor) { + auto float_tensor = fp16_tensor.to(torch::kFloat32); + auto int32_tensor = float_tensor.view(torch::kInt32); + auto int64_tensor = int32_tensor.to(torch::kInt64); + return int64_tensor; +} + +void DeekseekV2DecoderLoader::convert_descaled_weights_to_float() { + auto convert_to_float = [this](int index) { + at_weight_tensors_[index] = at_weight_tensors_[index].to(torch::kFloat32); + }; + convert_to_float(IN_Q_PROJ_A_DESCALE); + convert_to_float(IN_Q_PROJ_B_DESCALE); + convert_to_float(IN_KV_PROJ_WITH_MQA_DESCALE); + convert_to_float(IN_ATTENTION_OUT_DESCALE); +} + +void DeekseekV2DecoderLoader::reserve_experts_weights( + int num_of_device_experts) { + experts_weights_.clear(); + std::vector weight_names = { + "gate_proj.weight", "up_proj.weight", "down_proj.weight"}; + if (quantize_type_ == "w8a8_dynamic") { + weight_names.emplace_back("gate_proj.weight_offset"); + weight_names.emplace_back("up_proj.weight_offset"); + weight_names.emplace_back("down_proj.weight_offset"); + weight_names.emplace_back("gate_proj.weight_scale"); + weight_names.emplace_back("up_proj.weight_scale"); + weight_names.emplace_back("down_proj.weight_scale"); + } + std::lock_guard lock(experts_mutex_); + for (const auto& weight_name : weight_names) { + experts_weights_[weight_name] = + std::vector(num_of_device_experts); + } +} + +std::string DeekseekV2DecoderLoader::get_expert_shm_key( + int32_t layer_id, + int32_t expert_index, + const std::string& suffix) { + std::string shm_key = + "layer_" + std::to_string(layer_id - first_k_dense_replace_) + "_" + + "expert_" + 
std::to_string(expert_index) + "_" + suffix; + return shm_key; +} + +void DeekseekV2DecoderLoader::merge_shared_experts_weights() { + auto merge_and_clear = [this](int index, + torch::Tensor& shared_experts_gate, + torch::Tensor& shared_experts_up) { + at_weight_tensors_[index] = + torch::cat({shared_experts_gate, shared_experts_up}, 0) + .to(device_) + .contiguous(); + shared_experts_gate = tensor_placeholder_; + shared_experts_up = tensor_placeholder_; + }; + + if (layer_id_ >= prefill_firstKDenseReplace_) { + merge_and_clear( + IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, + shared_experts_weights_["mlp.shared_experts.gate_proj.weight"], + shared_experts_weights_["mlp.shared_experts.up_proj.weight"]); + if (quantize_type_ == "w8a8_dynamic") { + merge_and_clear( + IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, + shared_experts_weights_["mlp.shared_experts.gate_proj.weight_offset"], + shared_experts_weights_["mlp.shared_experts.up_proj.weight_offset"]); + merge_and_clear( + IN_MLP_GATEUP_SCALE_SHARED_EXPERT, + shared_experts_weights_["mlp.shared_experts.gate_proj.weight_scale"], + shared_experts_weights_["mlp.shared_experts.up_proj.weight_scale"]); + } + } else { + merge_and_clear(IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, + shared_experts_weights_["mlp.gate_proj.weight"], + shared_experts_weights_["mlp.up_proj.weight"]); + if (quantize_type_ == "w8a8_dynamic") { + merge_and_clear(IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, + shared_experts_weights_["mlp.gate_proj.weight_offset"], + shared_experts_weights_["mlp.up_proj.weight_offset"]); + merge_and_clear(IN_MLP_GATEUP_SCALE_SHARED_EXPERT, + shared_experts_weights_["mlp.gate_proj.weight_scale"], + shared_experts_weights_["mlp.up_proj.weight_scale"]); + } + } +} + +void DeekseekV2DecoderLoader::merge_loaded_weights() { + if (quantize_type_ == "w8a8_dynamic") { + if (prefill_isBF16_) { + convert_descaled_weights_to_float(); + } + convert_offsets_to_int8(); + handle_device_specific_bias(); + } + + merge_shared_experts_weights(); + if (layer_id_ >= prefill_firstKDenseReplace_) { + merge_experts_weights(); + } + + squeeze_experts_weights(); + + preprocess_linear_for_rope(); + + at_weight_tensors_[IN_Q_PROJ_A_WEIGHT] = + torch::cat({at_weight_tensors_[IN_KV_PROJ_WITH_MQA_WEIGHT], + at_weight_tensors_[IN_Q_PROJ_A_WEIGHT]}, + 0) + .contiguous(); + if (quantize_type_ == "w8a8_dynamic") { + at_weight_tensors_[IN_Q_PROJ_A_BIAS] = + torch::cat({at_weight_tensors_[IN_KV_PROJ_WITH_MQA_BIAS], + at_weight_tensors_[IN_Q_PROJ_A_BIAS]}, + 0) + .contiguous(); + at_weight_tensors_[IN_Q_PROJ_A_DESCALE] = + torch::cat({at_weight_tensors_[IN_KV_PROJ_WITH_MQA_DESCALE], + at_weight_tensors_[IN_Q_PROJ_A_DESCALE]}, + 0) + .contiguous(); + } + + at_weight_tensors_[IN_Q_PROJ_A_WEIGHT] = at_npu::native::npu_format_cast( + at_weight_tensors_[IN_Q_PROJ_A_WEIGHT], 29); + at_weight_tensors_[IN_Q_PROJ_B_WEIGHT] = at_npu::native::npu_format_cast( + at_weight_tensors_[IN_Q_PROJ_B_WEIGHT], 29); + + at_weight_tensors_[IN_KV_PROJ_WITH_MQA_WEIGHT] = tensor_placeholder_; + at_weight_tensors_[IN_KV_PROJ_WITH_MQA_BIAS] = tensor_placeholder_; + at_weight_tensors_[IN_KV_PROJ_WITH_MQA_DESCALE] = tensor_placeholder_; + at_weight_tensors_[IN_KV_PROJ_WITH_MQA_OFFSET] = tensor_placeholder_; + at_weight_tensors_[IN_KV_PROJ_WITH_MQA_SCALE] = tensor_placeholder_; + if (FLAGS_expert_parallel_degree != 2) { + at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT] = + torch::roll(at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT], + {-1 * ep_rank_ * num_experts_per_partition_}, + {0}) + .contiguous(); + 
at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_BIAS] = + torch::roll(at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_BIAS], + {-1 * ep_rank_ * num_experts_per_partition_}, + {0}) + .contiguous(); + } + // at_weight_tensors_[IN_MLP_DOWN_WEIGHT_SHARED_EXPERT] = + // at_weight_tensors_[IN_MLP_DOWN_WEIGHT_SHARED_EXPERT].transpose(0, 1); + at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT] = + at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT].to(torch::kFloat32); + if (quantize_type_ == "w8a8_dynamic") { + // at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT] = + // at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT].to(torch::kFloat32); + if (!prefill_isBF16_) { + at_weight_tensors_[IN_Q_PROJ_A_DESCALE] = + convert_fp16_to_int64(at_weight_tensors_[IN_Q_PROJ_A_DESCALE]); + at_weight_tensors_[IN_Q_PROJ_B_DESCALE] = + convert_fp16_to_int64(at_weight_tensors_[IN_Q_PROJ_B_DESCALE]); + at_weight_tensors_[IN_ATTENTION_OUT_DESCALE] = + convert_fp16_to_int64(at_weight_tensors_[IN_ATTENTION_OUT_DESCALE]); + + at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT] = + at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT].to( + torch::kFloat16); + at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT] = + at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT].to( + torch::kFloat32); + at_weight_tensors_[IN_MLP_DOWN_SCALE_SHARED_EXPERT] = + at_weight_tensors_[IN_MLP_DOWN_SCALE_SHARED_EXPERT].to( + torch::kFloat32); + at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT] = + at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT].to(torch::kFloat16); + at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT] = + at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT].to(torch::kFloat32); + at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = + at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT].to(torch::kFloat16); + at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = + at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT].to(torch::kFloat32); + } + } +} + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/deepseek_v2_decoder_loader.h b/xllm/core/layers/npu/loader/deepseek_v2_decoder_loader.h new file mode 100644 index 000000000..d2afd5f8d --- /dev/null +++ b/xllm/core/layers/npu/loader/deepseek_v2_decoder_loader.h @@ -0,0 +1,161 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "base_loader.h" + +namespace xllm { +namespace layer { + +class DeekseekV2DecoderLoader : public BaseLoader { + public: + DeekseekV2DecoderLoader(uint64_t weight_count, + const ModelContext& context, + int32_t layer_id, + int32_t prefill_firstKDenseReplace, + int32_t prefill_numOfDeviceExperts, + int32_t prefill_qkRopeHeadDim, + int32_t prefill_numAttentionHeadsPerRank, + int32_t decode_worldSize, + int32_t qk_nope_head_dim_, + int32_t kv_lora_rank, + int32_t num_key_value_heads, + int32_t v_head_dim, + bool prefill_isBF16, + bool decode_isBF16); + + void load_state_dict(const StateDict& state_dict) override; + void verify_loaded_weights(const std::string& prefix) const override; + void merge_loaded_weights() override; + + protected: + void initialize_device_expert_list(int num_device, int num_device_expert); + + int extract_expert_index(const std::string& name); + + std::string get_expert_shm_key(int32_t layer_id, + int32_t expert_index, + const std::string& suffix); + + int get_mapped_index(const std::string& name, + const std::unordered_map& mapping); + + torch::Tensor get_sharded_tensor(const StateDict& state_dict, + const std::string& name, + int dim); + + torch::Tensor get_sharded_tensor(const StateDict& state_dict, + const std::string& name, + int dim, + int local_tp_rank, + int local_tp_size); + + std::string extract_endswith(const std::string& input); + + void set_kv_weight(const StateDict& state_dict, + const std::string& tensor_name, + int weight_position, + int dim); + + torch::Tensor convert_fp16_to_int64(); + + void preprocess_linear_for_rope(); + + void process_expert_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + + void process_shared_expert_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + + void process_mlp_common_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + + void process_general_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + + void convert_descaled_weights_to_float(); + + void convert_offsets_to_int8(); + + torch::Tensor convert_fp16_to_int64(const torch::Tensor& fp16_tensor); + + void handle_device_specific_bias(); + + void merge_shared_experts_weights(); + + void merge_experts_weights(); + + torch::Tensor merge_experts_weights(std::vector& experts, + at::Device device, + bool transpose = false); + + torch::Tensor merge_experts_weights(std::vector& experts_up, + std::vector& experts_gate, + at::Device device, + bool transpose = false); + + void squeeze_experts_weights(); + + void update_expert_weight(); + + void prepare_expert_weight(const std::vector& expert_list); + + void initialize_weight_tensors(const torch::TensorOptions& options); + + void initialize_tensors(const torch::TensorOptions& options); + + torch::Tensor view_tensor(torch::Tensor weight, + const std::string& name, + bool pre_view); + + void reserve_experts_weights(int num_of_device_experts); + + torch::Tensor trans_rope_weight(torch::Tensor weight); + + int32_t rank_; + int32_t first_k_dense_replace_; + int32_t n_layers_; + int32_t localWorldSize_; + int32_t ep_size_; + int32_t num_experts_; + int32_t num_experts_per_partition_; + int32_t ep_local_tp_size_; + int32_t ep_local_tp_rank_; + int32_t start_expert_id_; + int32_t end_expert_id_; + int32_t ep_rank_; + int32_t redundant_experts_num_; + + int32_t layer_id_; 
+ int32_t qk_nope_head_dim_; + int32_t kv_lora_rank_; + int32_t v_head_dim_; + int32_t num_key_value_heads_; + int32_t prefill_firstKDenseReplace_; + int32_t prefill_numOfDeviceExperts_; + int32_t prefill_qkRopeHeadDim_; + int32_t prefill_numAttentionHeadsPerRank_; + int32_t decode_worldSize_; + bool prefill_isBF16_; + bool decode_isBF16_; + std::mutex shared_experts_mutex_; + std::mutex experts_mutex_; + + torch::Tensor tensor_placeholder_; +}; +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/glm4_moe_decoder_loader.cpp b/xllm/core/layers/npu/loader/glm4_moe_decoder_loader.cpp new file mode 100644 index 000000000..4072f3b50 --- /dev/null +++ b/xllm/core/layers/npu/loader/glm4_moe_decoder_loader.cpp @@ -0,0 +1,805 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "glm4_moe_decoder_loader.h" + +#include +#include +#include +#include + +#include +#include + +#include "core/layers/npu/npu_glm4_moe_decoder_layer.h" + +namespace xllm { +namespace layer { + +enum DecoderLayerTensorId : int { + IN_INPUT_NORM_WEIGHT = 0, + IN_INPUT_NORM_BIAS = 1, + IN_INPUT_NORM_NEW_WEIGHT = 2, + IN_INPUT_NORM_NEW_BIAS = 3, + + IN_QKV_WEIGHT_0 = 4, + IN_QKV_BIAS_0 = 5, + IN_QKV_DESCALE_0 = 6, + IN_QKV_OFFSET_0 = 7, + IN_QKV_SCALE_0 = 8, + IN_QKV_COMPRESS_IDX_0 = 9, + + IN_QKV_WEIGHT_1 = 10, + IN_QKV_BIAS_1 = 11, + IN_QKV_DESCALE_1 = 12, + IN_QKV_OFFSET_1 = 13, + IN_QKV_SCALE_1 = 14, + IN_QKV_COMPRESS_IDX_1 = 15, + + IN_QKV_WEIGHT_2 = 16, + IN_QKV_BIAS_2 = 17, + IN_QKV_DESCALE_2 = 18, + IN_QKV_OFFSET_2 = 19, + IN_QKV_SCALE_2 = 20, + IN_QKV_COMPRESS_IDX_2 = 21, + + IN_QKV_DENSE_WEIGHT = 22, + IN_QKV_DENSE_BIAS = 23, + IN_QKV_DENSE_DESCALE = 24, + IN_QKV_DENSE_OFFSET = 25, + IN_QKV_DENSE_SCALE = 26, + IN_QKV_DENSE_COMPRESS_IDX = 27, + + IN_POST_ATTN_NORM_WEIGHT = 28, + IN_POST_ATTN_NORM_BIAS = 29, + IN_POST_ATTN_NORM_NEW_WEIGHT = 30, + IN_POST_ATTN_NORM_NEW_BIAS = 31, + + IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT = 32, + IN_MLP_GATEUP_BIAS_SHARED_EXPERT = 33, + IN_MLP_GATEUP_DESCALE_SHARED_EXPERT = 34, + IN_MLP_GATEUP_OFFSET_SHARED_EXPERT = 35, + IN_MLP_GATEUP_SCALE_SHARED_EXPERT = 36, + IN_MLP_GATEUP_COMPRESS_IDX_SHARED_EXPERT = 37, + + IN_MLP_DOWN_WEIGHT_SHARED_EXPERT = 38, + IN_MLP_DOWN_BIAS_SHARED_EXPERT = 39, + IN_MLP_DOWN_DESCALE_SHARED_EXPERT = 40, + IN_MLP_DOWN_OFFSET_SHARED_EXPERT = 41, + IN_MLP_DOWN_SCALE_SHARED_EXPERT = 42, + IN_MLP_DOWN_COMPRESS_IDX_SHARED_EXPERT = 43, + + IN_SHARED_EXPERT_GATE_WEIGHT = 44, + IN_SHARED_EXPERT_GATE_BIAS = 45, + IN_SHARED_EXPERT_GATE_DESCALE = 46, + IN_SHARED_EXPERT_GATE_OFFSET = 47, + IN_SHARED_EXPERT_GATE_SCALE = 48, + IN_SHARED_EXPERT_GATE_COMPRESS_IDX = 49, + + BLOCK_SPARSE_MOE_GATE_WEIGHT = 50, + BLOCK_SPARSE_MOE_GATE_BIAS = 51, + BLOCK_SPARSE_MOE_GATE_DESCALE = 52, + BLOCK_SPARSE_MOE_GATE_OFFSET = 53, + BLOCK_SPARSE_MOE_GATE_SCALE = 54, + BLOCK_SPARSE_MOE_GATE_COMPRESS_IDX 
= 55, + + IN_MLP_GATEUP_WEIGHT = 56, + IN_MLP_GATEUP_BIAS = 57, + IN_MLP_GATEUP_DESCALE = 58, + IN_MLP_GATEUP_OFFSET = 59, + IN_MLP_GATEUP_SCALE = 60, + IN_MLP_GATEUP_COMPRESS_IDX = 61, + + IN_MLP_DOWN_WEIGHT = 62, + IN_MLP_DOWN_BIAS = 63, + IN_MLP_DOWN_DESCALE = 64, + IN_MLP_DOWN_OFFSET = 65, + IN_MLP_DOWN_SCALE = 66, + IN_MLP_DOWN_COMPRESS_IDX = 67, + + Q_NORM_WEIGHT = 68, + K_NORM_WEIGHT = 69 +}; + +static std::unordered_map WEIGHT_MAPPING = { + {"input_layernorm.weight", IN_INPUT_NORM_WEIGHT}, + + {"self_attn.q_proj.weight", IN_QKV_WEIGHT_0}, + {"self_attn.q_proj.bias", IN_QKV_BIAS_0}, + + {"self_attn.k_proj.weight", IN_QKV_WEIGHT_1}, + {"self_attn.k_proj.bias", IN_QKV_BIAS_1}, + + {"self_attn.v_proj.weight", IN_QKV_WEIGHT_2}, + {"self_attn.v_proj.bias", IN_QKV_BIAS_2}, + + {"self_attn.o_proj.weight", IN_QKV_DENSE_WEIGHT}, + + {"post_attention_layernorm.weight", IN_POST_ATTN_NORM_WEIGHT}, + + // mlp or shared expert + {"mlp.gate_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + + {"mlp.up_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + + {"mlp.down_proj.weight", IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, + + {"mlp.shared_experts.gate_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + + {"mlp.shared_experts.up_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + + {"mlp.shared_experts.down_proj.weight", IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, + + // MoE Gate + {"mlp.gate.weight", BLOCK_SPARSE_MOE_GATE_WEIGHT}, + {"mlp.gate.e_score_correction_bias", BLOCK_SPARSE_MOE_GATE_BIAS}, + + // Expert MLP - Gate/Up projections + {"gate_proj.weight", IN_MLP_GATEUP_WEIGHT}, + {"up_proj.weight", IN_MLP_GATEUP_WEIGHT}, + + // Expert MLP - Down projection + {"down_proj.weight", IN_MLP_DOWN_WEIGHT}, + +}; + +static std::unordered_map WEIGHT_MAPPING_W8A8 = { + {"input_layernorm.weight", IN_INPUT_NORM_WEIGHT}, + {"input_layernorm.bias", IN_INPUT_NORM_NEW_BIAS}, + + {"self_attn.q_proj.weight", IN_QKV_WEIGHT_0}, + {"self_attn.q_proj.deq_scale", IN_QKV_DESCALE_0}, + {"self_attn.q_proj.quant_bias", IN_QKV_BIAS_0}, + {"self_attn.q_proj.input_offset", IN_QKV_OFFSET_0}, + {"self_attn.q_proj.input_scale", IN_QKV_SCALE_0}, + + {"self_attn.k_proj.weight", IN_QKV_WEIGHT_1}, + {"self_attn.k_proj.deq_scale", IN_QKV_DESCALE_1}, + {"self_attn.k_proj.quant_bias", IN_QKV_BIAS_1}, + + {"self_attn.v_proj.weight", IN_QKV_WEIGHT_2}, + {"self_attn.v_proj.deq_scale", IN_QKV_DESCALE_2}, + {"self_attn.v_proj.quant_bias", IN_QKV_BIAS_2}, + + {"self_attn.o_proj.weight", IN_QKV_DENSE_WEIGHT}, + {"self_attn.o_proj.quant_bias", IN_QKV_DENSE_BIAS}, + {"self_attn.o_proj.deq_scale", IN_QKV_DENSE_DESCALE}, + {"self_attn.o_proj.weight_offset", IN_QKV_DENSE_OFFSET}, + {"self_attn.o_proj.weight_scale", IN_QKV_DENSE_SCALE}, + + {"post_attention_layernorm.weight", IN_POST_ATTN_NORM_WEIGHT}, + {"post_attention_layernorm.bias", IN_POST_ATTN_NORM_NEW_BIAS}, + + // mlp + {"mlp.gate_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + {"mlp.gate_proj.weight_offset", IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, + {"mlp.gate_proj.weight_scale", IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, + + {"mlp.up_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + {"mlp.up_proj.weight_offset", IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, + {"mlp.up_proj.weight_scale", IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, + + {"mlp.down_proj.weight", IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, + {"mlp.down_proj.weight_offset", IN_MLP_DOWN_OFFSET_SHARED_EXPERT}, + {"mlp.down_proj.weight_scale", IN_MLP_DOWN_SCALE_SHARED_EXPERT}, + + // shared expert + {"mlp.shared_experts.gate_proj.weight", 
IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + {"mlp.shared_experts.gate_proj.weight_offset", + IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, + {"mlp.shared_experts.gate_proj.weight_scale", + IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, + + {"mlp.shared_experts.up_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, + {"mlp.shared_experts.up_proj.weight_offset", + IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, + {"mlp.shared_experts.up_proj.weight_scale", + IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, + + {"mlp.shared_experts.down_proj.weight", IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, + {"mlp.shared_experts.down_proj.weight_offset", + IN_MLP_DOWN_OFFSET_SHARED_EXPERT}, + {"mlp.shared_experts.down_proj.weight_scale", + IN_MLP_DOWN_SCALE_SHARED_EXPERT}, + + // MoE Gate + {"mlp.gate.weight", BLOCK_SPARSE_MOE_GATE_WEIGHT}, + {"mlp.gate.e_score_correction_bias", BLOCK_SPARSE_MOE_GATE_BIAS}, + + {"gate_proj.weight", IN_MLP_GATEUP_WEIGHT}, + {"gate_proj.weight_offset", IN_MLP_GATEUP_OFFSET}, + {"gate_proj.weight_scale", IN_MLP_GATEUP_SCALE}, + {"up_proj.weight", IN_MLP_GATEUP_WEIGHT}, + {"up_proj.weight_offset", IN_MLP_GATEUP_OFFSET}, + {"up_proj.weight_scale", IN_MLP_GATEUP_SCALE}, + + {"down_proj.weight", IN_MLP_DOWN_WEIGHT}, + {"down_proj.weight_offset", IN_MLP_DOWN_OFFSET}, + {"down_proj.weight_scale", IN_MLP_DOWN_SCALE}, +}; + +static const std::unordered_map> + SPECIAL_MULTI_ASSIGN_W8A8 = { + {"input_layernorm.weight", + {IN_INPUT_NORM_WEIGHT, IN_INPUT_NORM_NEW_WEIGHT}}, + {"post_attention_layernorm.weight", + {IN_POST_ATTN_NORM_WEIGHT, IN_POST_ATTN_NORM_NEW_WEIGHT}}, +}; + +static const std::map WEIGHT_SHARD = { + {IN_QKV_WEIGHT_0, 0}, + {IN_QKV_BIAS_0, 0}, + {IN_QKV_WEIGHT_1, 0}, + {IN_QKV_BIAS_1, 0}, + {IN_QKV_WEIGHT_2, 0}, + {IN_QKV_BIAS_2, 0}, + {IN_QKV_DENSE_WEIGHT, 1}, + {IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, 0}, + {IN_MLP_DOWN_WEIGHT_SHARED_EXPERT, 1}, + {IN_MLP_GATEUP_WEIGHT, 0}, + {IN_MLP_DOWN_WEIGHT, 1}, +}; + +static const std::map WEIGHT_SHARD_W8A8 = { + {IN_QKV_WEIGHT_0, 0}, + {IN_QKV_BIAS_0, 0}, + {IN_QKV_DESCALE_0, 0}, + {IN_QKV_WEIGHT_1, 0}, + {IN_QKV_BIAS_1, 0}, + {IN_QKV_DESCALE_1, 0}, + {IN_QKV_WEIGHT_2, 0}, + {IN_QKV_BIAS_2, 0}, + {IN_QKV_DESCALE_2, 0}, + {IN_QKV_DENSE_WEIGHT, 1}, + {IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, 0}, + {IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, 0}, + {IN_MLP_GATEUP_SCALE_SHARED_EXPERT, 0}, + {IN_MLP_DOWN_WEIGHT_SHARED_EXPERT, 1}, + {IN_MLP_GATEUP_WEIGHT, 0}, + {IN_MLP_GATEUP_OFFSET, 0}, + {IN_MLP_GATEUP_SCALE, 0}, + {IN_MLP_DOWN_WEIGHT, 1}, +}; + +Glm4MoeDecoderLoader::Glm4MoeDecoderLoader( + uint64_t weight_count, + const ModelContext& context, + int32_t layer_id, + int32_t prefill_param_firstKDenseReplace) + : BaseLoader(weight_count, context), + layer_id_(layer_id), + prefill_param_firstKDenseReplace_(prefill_param_firstKDenseReplace) { + auto model_args = context.get_model_args(); + auto parallel_args = context.get_parallel_args(); + auto options = context.get_tensor_options(); + + tensor_placeholder_ = torch::zeros({1}).to(options); + + if (model_args.use_qk_norm()) { + weight_count_ = weight_count = 70; + WEIGHT_MAPPING_W8A8["self_attn.q_norm.weight"] = Q_NORM_WEIGHT; + WEIGHT_MAPPING_W8A8["self_attn.k_norm.weight"] = K_NORM_WEIGHT; + WEIGHT_MAPPING["self_attn.q_norm.weight"] = Q_NORM_WEIGHT; + WEIGHT_MAPPING["self_attn.k_norm.weight"] = K_NORM_WEIGHT; + } + + at_weight_tensors_.resize(weight_count_); + + num_experts_ = model_args.num_experts(); + ep_size_ = parallel_args.ep_size(); + ep_local_tp_size_ = parallel_args.world_size() / ep_size_; + CHECK_EQ(parallel_args.world_size(), ep_size_ * 
ep_local_tp_size_); + ep_local_tp_rank_ = parallel_args.rank() % ep_local_tp_size_; + num_experts_per_partition_ = model_args.num_experts() / ep_size_; + ep_rank_ = parallel_args.rank() / ep_local_tp_size_; + start_expert_id_ = ep_rank_ * num_experts_per_partition_; + end_expert_id_ = start_expert_id_ + num_experts_per_partition_ - 1; + + dp_size_ = parallel_args.dp_size(); + dp_local_tp_size_ = parallel_args.world_size() / dp_size_; + CHECK_EQ(parallel_args.world_size(), dp_size_ * dp_local_tp_size_); + dp_local_tp_rank_ = parallel_args.rank() % dp_local_tp_size_; + + n_kv_heads_ = static_cast(model_args.n_kv_heads().value()); +} + +void Glm4MoeDecoderLoader::resize_experts_weights(int num_of_device_experts) { + experts_weights_["gate_proj.weight"] = + std::vector(num_of_device_experts); + experts_weights_["up_proj.weight"] = + std::vector(num_of_device_experts); + experts_weights_["down_proj.weight"] = + std::vector(num_of_device_experts); + if (quantize_type_.compare("w8a8_dynamic") == 0) { + experts_weights_["gate_proj.weight_offset"] = + std::vector(num_of_device_experts); + experts_weights_["up_proj.weight_offset"] = + std::vector(num_of_device_experts); + experts_weights_["down_proj.weight_offset"] = + std::vector(num_of_device_experts); + experts_weights_["gate_proj.weight_scale"] = + std::vector(num_of_device_experts); + experts_weights_["up_proj.weight_scale"] = + std::vector(num_of_device_experts); + experts_weights_["down_proj.weight_scale"] = + std::vector(num_of_device_experts); + } +} + +void Glm4MoeDecoderLoader::load_state_dict(const StateDict& state_dict) { + for (const auto& [name, tensor] : state_dict) { + bool is_sharded = false; + int index = 0; + + if (absl::StartsWith(name, "mlp.experts")) { + process_expert_weights(state_dict, name, tensor); + continue; + } + if (absl::StartsWith(name, "mlp.shared_experts")) { + process_shared_expert_weights(state_dict, name, tensor); + continue; + } + if (absl::StartsWith(name, "mlp") && !absl::StrContains(name, "gate.")) { + process_mlp_common_weights(state_dict, name, tensor); + continue; + } + + process_general_weights(state_dict, name, tensor); + } +} + +void Glm4MoeDecoderLoader::verify_loaded_weights() const { + for (const auto& [name, index] : WEIGHT_MAPPING) { + if (name == "down_proj.weight" || name == "gate_proj.weight" || + name == "up_proj.weight" || name == "mlp.gate.weight" || + name == "mlp.gate.e_score_correction_bias") { + continue; + } + CHECK(at_weight_tensors_[index].sizes() != std::vector({0})) + << layer_id_ << "-weight is not loaded for " << name; + } +} + +void Glm4MoeDecoderLoader::merge_loaded_weights() { + merge_shared_experts_weights(); + if (layer_id_ >= prefill_param_firstKDenseReplace_) { + merge_experts_weights(); + } + at_weight_tensors_[IN_QKV_WEIGHT_0] = + torch::cat({at_weight_tensors_[IN_QKV_WEIGHT_0], + at_weight_tensors_[IN_QKV_WEIGHT_1], + at_weight_tensors_[IN_QKV_WEIGHT_2]}, + 0) + .contiguous(); + at_weight_tensors_[IN_QKV_WEIGHT_1] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_WEIGHT_2] = + torch::zeros({1}, torch::kFloat16).to(device_); + + at_weight_tensors_[IN_QKV_BIAS_0] = + at_weight_tensors_[IN_QKV_BIAS_0].squeeze(); + at_weight_tensors_[IN_QKV_BIAS_1] = + at_weight_tensors_[IN_QKV_BIAS_1].squeeze(); + at_weight_tensors_[IN_QKV_BIAS_2] = + at_weight_tensors_[IN_QKV_BIAS_2].squeeze(); + + at_weight_tensors_[IN_QKV_BIAS_0] = + torch::cat({at_weight_tensors_[IN_QKV_BIAS_0], + at_weight_tensors_[IN_QKV_BIAS_1], + at_weight_tensors_[IN_QKV_BIAS_2]}, + 
0) + .contiguous(); + at_weight_tensors_[IN_QKV_BIAS_1] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_BIAS_2] = + torch::zeros({1}, torch::kFloat16).to(device_); + + if (quantize_type_.compare("w8a8_dynamic") == 0) { + at_weight_tensors_[IN_QKV_DESCALE_0] = + at_weight_tensors_[IN_QKV_DESCALE_0].squeeze(); + at_weight_tensors_[IN_QKV_DESCALE_1] = + at_weight_tensors_[IN_QKV_DESCALE_1].squeeze(); + at_weight_tensors_[IN_QKV_DESCALE_2] = + at_weight_tensors_[IN_QKV_DESCALE_2].squeeze(); + + at_weight_tensors_[IN_QKV_DESCALE_0] = + torch::cat({at_weight_tensors_[IN_QKV_DESCALE_0], + at_weight_tensors_[IN_QKV_DESCALE_1], + at_weight_tensors_[IN_QKV_DESCALE_2]}, + 0) + .contiguous(); + + at_weight_tensors_[IN_QKV_DESCALE_1] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_DESCALE_2] = + torch::zeros({1}, torch::kFloat16).to(device_); + + at_weight_tensors_[IN_QKV_DENSE_BIAS] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_DENSE_DESCALE] = + torch::zeros({1}, torch::kFloat16).to(device_); + + at_weight_tensors_[IN_QKV_OFFSET_0] = + at_weight_tensors_[IN_QKV_OFFSET_0].to(torch::kInt8).to(device_); + at_weight_tensors_[IN_QKV_OFFSET_1] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_OFFSET_2] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_DENSE_OFFSET] = + at_weight_tensors_[IN_QKV_DENSE_OFFSET].contiguous().view(-1); + + at_weight_tensors_[IN_QKV_SCALE_1] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_SCALE_2] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_DENSE_SCALE] = + at_weight_tensors_[IN_QKV_DENSE_SCALE].contiguous().view(-1); + } +} + +void Glm4MoeDecoderLoader::process_expert_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + int expert_index = extract_expert_index(name); + if (expert_index < start_expert_id_ || expert_index > end_expert_id_) { + return; + } + + const std::string suffix = extract_endswith(name); + const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_MAPPING_W8A8 + : WEIGHT_MAPPING; + const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_SHARD_W8A8 + : WEIGHT_SHARD; + const int index = get_mapped_index(suffix, weight_mapping); + const int local_index = expert_index % num_experts_per_partition_; + const bool is_sharded = shard_map.count(index); + + std::lock_guard lock(experts_mutex_); + torch::Tensor tmp_tensor = is_sharded + ? get_sharded_tensor(state_dict, + name, + shard_map.at(index), + ep_local_tp_rank_, + ep_local_tp_size_) + : tensor; + + experts_weights_[suffix][local_index] = tmp_tensor.clone(); +} + +void Glm4MoeDecoderLoader::process_shared_expert_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + torch::Tensor tmp_tensor; + const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_MAPPING_W8A8 + : WEIGHT_MAPPING; + const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_SHARD_W8A8 + : WEIGHT_SHARD; + std::lock_guard lock(shared_experts_mutex_); + const int index = get_mapped_index(name, weight_mapping); + if (index == -1) { + return; + } + + const bool is_sharded = shard_map.count(index); + tmp_tensor = is_sharded + ? 
get_sharded_tensor(state_dict, name, shard_map.at(index)) + .to(device_) + : tensor.to(device_); + + if (absl::StrContains(name, "down_proj")) { + at_weight_tensors_[index] = tmp_tensor; + } else { + shared_experts_weights_[name] = tmp_tensor; + } +} + +void Glm4MoeDecoderLoader::process_mlp_common_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_MAPPING_W8A8 + : WEIGHT_MAPPING; + const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_SHARD_W8A8 + : WEIGHT_SHARD; + const int index = get_mapped_index(name, weight_mapping); + const bool is_sharded = shard_map.count(index); + + std::lock_guard lock(shared_experts_mutex_); + + torch::Tensor tmp_tensor = is_sharded + ? get_sharded_tensor(state_dict, + name, + shard_map.at(index), + dp_local_tp_rank_, + dp_local_tp_size_) + .to(device_) + : tensor.to(device_); + if (absl::StrContains(name, "down_proj")) { + at_weight_tensors_[index] = tmp_tensor; + } else { + shared_experts_weights_[name] = tmp_tensor; + } +} + +void Glm4MoeDecoderLoader::process_general_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_MAPPING_W8A8 + : WEIGHT_MAPPING; + const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_SHARD_W8A8 + : WEIGHT_SHARD; + + if (weight_mapping.find(name) == weight_mapping.end()) { + return; + } + + const int index = get_mapped_index(name, weight_mapping); + const bool is_sharded = shard_map.count(index); + torch::Tensor tmp_tensor; + int32_t tp_rank = dp_local_tp_rank_; + int32_t tp_size = dp_local_tp_size_; + if (index == IN_QKV_WEIGHT_1 || index == IN_QKV_WEIGHT_2 || + index == IN_QKV_BIAS_1 || index == IN_QKV_BIAS_2 || + index == IN_QKV_DESCALE_1 || index == IN_QKV_DESCALE_2) { + if (n_kv_heads_ < dp_local_tp_size_) { + int32_t repeat_times = (dp_local_tp_size_ / n_kv_heads_); + tp_rank = tp_rank / repeat_times; + tp_size = n_kv_heads_; + } + } + if (is_sharded) { + tmp_tensor = get_sharded_tensor( + state_dict, name, shard_map.at(index), tp_rank, tp_size) + .to(device_); + } else { + tmp_tensor = tensor.to(device_); + } + if (index == BLOCK_SPARSE_MOE_GATE_BIAS) { + auto min_val = tmp_tensor.min(); + tmp_tensor = tmp_tensor - min_val; + } + correct_tensor_dtype(tmp_tensor, name); + if (quantize_type_.compare("w8a8_dynamic") == 0) { + auto it = SPECIAL_MULTI_ASSIGN_W8A8.find(name); + if (it != SPECIAL_MULTI_ASSIGN_W8A8.end()) { + for (int idx : it->second) { + at_weight_tensors_[idx] = tmp_tensor; + } + return; + } + } + at_weight_tensors_[index] = tmp_tensor; +} + +torch::Tensor Glm4MoeDecoderLoader::get_sharded_tensor( + const StateDict& state_dict, + const std::string& name, + int dim) { + if (parallel_args_.world_size() > 1) { + return state_dict.get_sharded_tensor( + name, dim, parallel_args_.rank(), parallel_args_.world_size()); + } else { + return state_dict.get_tensor(name); + } +} + +torch::Tensor Glm4MoeDecoderLoader::get_sharded_tensor( + const StateDict& state_dict, + const std::string& name, + int dim, + int loacal_tp_rank, + int local_tp_size) { + if (local_tp_size > 1) { + return state_dict.get_sharded_tensor( + name, dim, loacal_tp_rank, local_tp_size); + } else { + return state_dict.get_tensor(name); + } +} + +std::string Glm4MoeDecoderLoader::extract_endswith(const std::string& input) { + std::vector 
parts; + std::stringstream ss(input); + std::string part; + while (std::getline(ss, part, '.')) { + parts.push_back(part); + } + if (parts.size() < 2) { + return ""; + } + std::string result = parts[parts.size() - 2] + "." + parts[parts.size() - 1]; + + return result; +} + +int Glm4MoeDecoderLoader::extract_expert_index(const std::string& name) { + std::string prefix = "experts."; + size_t pos = name.find(prefix); + if (pos != std::string::npos) { + pos += prefix.length(); + size_t end_pos = pos; + while (end_pos < name.length() && std::isdigit(name[end_pos])) { + ++end_pos; + } + if (end_pos > pos) { + return std::stoi(name.substr(pos, end_pos - pos)); + } + } + + return -1; +} + +void Glm4MoeDecoderLoader::merge_shared_experts_weights() { + auto merge_and_clear = [this](int index, + torch::Tensor& shared_experts_gate, + torch::Tensor& shared_experts_up) { + at_weight_tensors_[index] = + torch::cat({shared_experts_gate, shared_experts_up}, 0) + .to(device_) + .contiguous(); + shared_experts_gate = tensor_placeholder_; + shared_experts_up = tensor_placeholder_; + }; + + if (layer_id_ >= prefill_param_firstKDenseReplace_) { + merge_and_clear( + IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, + shared_experts_weights_["mlp.shared_experts.gate_proj.weight"], + shared_experts_weights_["mlp.shared_experts.up_proj.weight"]); + if (quantize_type_ == "w8a8_dynamic") { + merge_and_clear( + IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, + shared_experts_weights_["mlp.shared_experts.gate_proj.weight_offset"], + shared_experts_weights_["mlp.shared_experts.up_proj.weight_offset"]); + merge_and_clear( + IN_MLP_GATEUP_SCALE_SHARED_EXPERT, + shared_experts_weights_["mlp.shared_experts.gate_proj.weight_scale"], + shared_experts_weights_["mlp.shared_experts.up_proj.weight_scale"]); + at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT] = + at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT].squeeze(); + at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT] = + at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT].squeeze(); + at_weight_tensors_[IN_MLP_DOWN_OFFSET_SHARED_EXPERT] = + at_weight_tensors_[IN_MLP_DOWN_OFFSET_SHARED_EXPERT].squeeze(); + at_weight_tensors_[IN_MLP_DOWN_SCALE_SHARED_EXPERT] = + at_weight_tensors_[IN_MLP_DOWN_SCALE_SHARED_EXPERT].squeeze(); + } + } else { + merge_and_clear(IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, + shared_experts_weights_["mlp.gate_proj.weight"], + shared_experts_weights_["mlp.up_proj.weight"]); + if (quantize_type_ == "w8a8_dynamic") { + merge_and_clear(IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, + shared_experts_weights_["mlp.gate_proj.weight_offset"], + shared_experts_weights_["mlp.up_proj.weight_offset"]); + merge_and_clear(IN_MLP_GATEUP_SCALE_SHARED_EXPERT, + shared_experts_weights_["mlp.gate_proj.weight_scale"], + shared_experts_weights_["mlp.up_proj.weight_scale"]); + at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT] = + at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT].squeeze(); + at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT] = + at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT].squeeze(); + } + } +} + +void Glm4MoeDecoderLoader::merge_experts_weights() { + try { + torch::Tensor mlp_gateup_weight; + if (quantize_type_.compare("w8a8_dynamic") == 0) { + mlp_gateup_weight = + merge_experts_weights(experts_weights_["gate_proj.weight"], + experts_weights_["up_proj.weight"], + /*transpose=*/true); + + at_weight_tensors_[IN_MLP_GATEUP_OFFSET] = + merge_experts_weights(experts_weights_["gate_proj.weight_offset"], + experts_weights_["up_proj.weight_offset"]); + 
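/* The offset and scale tensors for the fused gate/up expert projection are merged with the same per-rank expert stacking as the weights above; only the weight tensor itself is transposed before the NPU format cast below. */ +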
at_weight_tensors_[IN_MLP_GATEUP_SCALE] = + merge_experts_weights(experts_weights_["gate_proj.weight_scale"], + experts_weights_["up_proj.weight_scale"]); + at_weight_tensors_[IN_MLP_GATEUP_WEIGHT] = + at_npu::native::npu_format_cast(mlp_gateup_weight, 29); + } else { + mlp_gateup_weight = + merge_experts_weights(experts_weights_["gate_proj.weight"], + experts_weights_["up_proj.weight"], + /*transpose=*/false); + at_weight_tensors_[IN_MLP_GATEUP_WEIGHT] = + at_npu::native::npu_format_cast(mlp_gateup_weight, 2).contiguous(); + } + } catch (const std::exception& e) { + LOG(ERROR) << "[ERROR] Exception in gateup weight processing: " << e.what(); + throw; + } + + if (experts_weights_.count("down_proj.weight") > 0) { + auto& down_weight = experts_weights_["down_proj.weight"]; + } + + try { + torch::Tensor mlp_down_weight = + merge_experts_weights(experts_weights_["down_proj.weight"], + /*transpose=*/false); + + at_weight_tensors_[IN_MLP_DOWN_WEIGHT] = + at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); + + if (quantize_type_.compare("w8a8_dynamic") == 0) { + at_weight_tensors_[IN_MLP_DOWN_OFFSET] = + merge_experts_weights(experts_weights_["down_proj.weight_offset"]); + at_weight_tensors_[IN_MLP_DOWN_SCALE] = + merge_experts_weights(experts_weights_["down_proj.weight_scale"]); + } + } catch (const std::exception& e) { + LOG(ERROR) << "[ERROR] Exception in down weight processing: " << e.what(); + throw; + } +} + +torch::Tensor Glm4MoeDecoderLoader::merge_experts_weights( + std::vector& experts, + bool transpose) { + torch::Tensor merged_tensor = torch::stack(experts, 0).to(device_); + if (transpose) { + merged_tensor = merged_tensor.transpose(1, 2); + } + merged_tensor = merged_tensor.contiguous(); + experts.clear(); + + return merged_tensor; +} + +torch::Tensor Glm4MoeDecoderLoader::merge_experts_weights( + std::vector& experts_gate, + std::vector& experts_up, + bool transpose) { + for (size_t i = 0; i < experts_up.size(); ++i) { + experts_gate[i] = torch::cat({experts_gate[i], experts_up[i]}, 0); + } + torch::Tensor merged_tensor = torch::stack(experts_gate, 0).to(device_); + if (transpose) { + merged_tensor = merged_tensor.transpose(1, 2); + } + merged_tensor = merged_tensor.contiguous(); + experts_gate.clear(); + experts_up.clear(); + + return merged_tensor; +} + +int Glm4MoeDecoderLoader::get_mapped_index( + const std::string& name, + const std::unordered_map& mapping) { + const auto it = mapping.find(name); + if (it == mapping.end()) { + LOG(ERROR) << "Missing mapping for: " << name; + return -1; + } + return it->second; +} + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/glm4_moe_decoder_loader.h b/xllm/core/layers/npu/loader/glm4_moe_decoder_loader.h new file mode 100644 index 000000000..ec4e22917 --- /dev/null +++ b/xllm/core/layers/npu/loader/glm4_moe_decoder_loader.h @@ -0,0 +1,121 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include + +#include "base_loader.h" +#include "framework/model/model_args.h" +#include "framework/model/npu_dp_ep_padding.h" +#include "framework/quant_args.h" +#include "framework/state_dict/state_dict.h" +#include "xllm_kernels/models/glm/layer/moe_decoder_layer.h" + +namespace xllm { +namespace layer { + +class Glm4MoeDecoderLoader : public BaseLoader { + public: + Glm4MoeDecoderLoader(uint64_t weight_count, + const ModelContext& context, + int32_t layer_id, + int32_t prefill_param_firstKDenseReplace); + void load_state_dict(const StateDict& state_dict) override; + void verify_loaded_weights() const override; + void merge_loaded_weights() override; + + void resize_experts_weights(int num_of_device_experts) override; + + int32_t layer_id_; + int32_t prefill_param_firstKDenseReplace_; + + int32_t ep_size_; + int32_t num_experts_; + int32_t num_experts_per_partition_; + int32_t ep_local_tp_size_; + int32_t ep_local_tp_rank_; + int32_t start_expert_id_; + int32_t end_expert_id_; + int32_t ep_rank_; + int32_t n_kv_heads_; + + int32_t dp_size_; + int32_t dp_local_tp_size_; + int32_t dp_rank_; + int32_t dp_local_tp_rank_; + + torch::Tensor tensor_placeholder_; + + std::unordered_map<std::string, torch::Tensor> shared_experts_weights_; + std::unordered_map<std::string, std::vector<torch::Tensor>> experts_weights_; + + std::mutex shared_experts_mutex_; + std::mutex experts_mutex_; + + torch::ScalarType dtype_; + + void process_expert_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + + void process_shared_expert_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + + void process_mlp_common_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + + void process_general_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + + torch::Tensor get_sharded_tensor(const StateDict& state_dict, + const std::string& name, + int dim); + torch::Tensor get_sharded_tensor(const StateDict& state_dict, + const std::string& name, + int dim, + int local_tp_rank, + int local_tp_size); + + std::string extract_endswith(const std::string& input); + + int extract_expert_index(const std::string& name); + + void merge_shared_experts_weights(); + + void merge_experts_weights(); + + torch::Tensor merge_experts_weights(std::vector<torch::Tensor>& experts, + bool transpose = false); + + torch::Tensor merge_experts_weights(std::vector<torch::Tensor>& experts_gate, + std::vector<torch::Tensor>& experts_up, + bool transpose = false); + + // int64_t init_layer(); + + int get_mapped_index(const std::string& name, + const std::unordered_map<std::string, int>& mapping); +}; + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/llama_decoder_loader.cpp b/xllm/core/layers/npu/loader/llama_decoder_loader.cpp new file mode 100644 index 000000000..50051a500 --- /dev/null +++ b/xllm/core/layers/npu/loader/llama_decoder_loader.cpp @@ -0,0 +1,152 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llama_decoder_loader.h" + +namespace xllm { +namespace layer { + +enum DecoderLayerTensorId : int { + + IN_NORM_WEIGHT = 0, // weight + IN_NORM_BIAS, // bias + IN_NORM_NEW_WEIGHT, // new weight + IN_NORM_NEW_BIAS, // new bias + + IN_Q_WEIGHT, // weight + IN_Q_BIAS, // bias + IN_Q_DEQSCALE, // deq_scale + IN_Q_OFFSET, // offset + IN_Q_SCALE, // scale + IN_Q_COMPRESS_IDX, + + IN_K_WEIGHT, // weight + IN_K_BIAS, // bias + IN_K_DEQSCALE, // deq_scale + IN_K_OFFSET, // offset + IN_K_SCALE, // scale + IN_K_COMPRESS_IDX, + + IN_V_WEIGHT, // weight + IN_V_BIAS, // bias + IN_V_DEQSCALE, // deq_scale + IN_V_OFFSET, // offset + IN_V_SCALE, // scale + IN_V_COMPRESS_IDX, + + IN_ATTENTION_OUT_WEIGHT, // weight + IN_ATTENTION_OUT_BIAS, // bias + IN_ATTENTION_OUT_DEQSCALE, // deq_scale + IN_ATTENTION_OUT_OFFSET, // offset + IN_ATTENTION_OUT_SCALE, // scale + IN_ATTENTION_OUT_COMPRESS_IDX, + + IN_SELFOUT_NORM_WEIGHT, // weight + IN_SELFOUT_NORM_BIAS, // bias + IN_SELFOUT_NORM_NEW_WEIGHT, // new weight + IN_SELFOUT_NORM_NEW_BIAS, // new bias + + IN_MLP_W2_WEIGHT, // weight + IN_MLP_W2_BIAS, // bias + IN_MLP_W2_DEQSCALE, // deq_scale + IN_MLP_W2_OFFSET, // offset + IN_MLP_W2_SCALE, // scale + IN_MLP_W2_COMPRESS_IDX, + + IN_MLP_W1_WEIGHT, // weight + IN_MLP_W1_BIAS, // bias + IN_MLP_W1_DEQSCALE, // deq_scale + IN_MLP_W1_OFFSET, // offset + IN_MLP_W1_SCALE, // scale + IN_MLP_W1_COMPRESS_IDX, + + IN_MLP_CPROJ_WEIGHT, // weight + IN_MLP_CPROJ_BIAS, // bias + IN_MLP_CPROJ_DEQSCALE, // deq_scale + IN_MLP_CPROJ_OFFSET, // offset + IN_MLP_CPROJ_SCALE, // scale + IN_MLP_CPROJ_COMPRESS_IDX, +}; + +static const std::unordered_map<std::string, int> WEIGHT_MAPPING = { + {"input_layernorm.weight", IN_NORM_WEIGHT}, + {"self_attn.q_proj.weight", IN_Q_WEIGHT}, + {"self_attn.k_proj.weight", IN_K_WEIGHT}, + {"self_attn.v_proj.weight", IN_V_WEIGHT}, + {"self_attn.o_proj.weight", IN_ATTENTION_OUT_WEIGHT}, + {"post_attention_layernorm.weight", IN_SELFOUT_NORM_WEIGHT}, + {"mlp.gate_proj.weight", IN_MLP_W2_WEIGHT}, + {"mlp.up_proj.weight", IN_MLP_W1_WEIGHT}, + {"mlp.down_proj.weight", IN_MLP_CPROJ_WEIGHT}, +}; + +static std::map<int, int> WEIGHT_SHARD = {{IN_Q_WEIGHT, 0}, + {IN_K_WEIGHT, 0}, + {IN_V_WEIGHT, 0}, + {IN_ATTENTION_OUT_WEIGHT, 1}, + {IN_MLP_W2_WEIGHT, 0}, + {IN_MLP_W1_WEIGHT, 0}, + {IN_MLP_CPROJ_WEIGHT, 1}}; + +LlamaDecoderLoader::LlamaDecoderLoader(uint64_t weight_count, + const ModelContext& context) + : BaseLoader(weight_count, context) { + at_weight_tensors_.resize(weight_count); + + auto options = context.get_tensor_options(); + dtype_ = torch::typeMetaToScalarType(options.dtype()); + + for (int i = 0; i < weight_count; ++i) { + at_weight_tensors_[i] = torch::zeros({1}).to(options); + } +} + +void LlamaDecoderLoader::verify_loaded_weights() const { + for (const auto& [name, index] : WEIGHT_MAPPING) { + CHECK(at_weight_tensors_[index].sizes() != std::vector<int64_t>({1})) + << "weight is not loaded for " << name; + } +} + +void LlamaDecoderLoader::merge_loaded_weights() { + auto new_q_weight = torch::cat({at_weight_tensors_[IN_Q_WEIGHT], +
at_weight_tensors_[IN_K_WEIGHT], + at_weight_tensors_[IN_V_WEIGHT]}, + 0); + at_weight_tensors_[IN_Q_WEIGHT] = new_q_weight; + + at_weight_tensors_[IN_K_WEIGHT] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_V_WEIGHT] = torch::zeros({1}).to(device_); + + auto new_mlp_weight = torch::cat({at_weight_tensors_[IN_MLP_W2_WEIGHT], + at_weight_tensors_[IN_MLP_W1_WEIGHT]}, + 0); + at_weight_tensors_[IN_MLP_W2_WEIGHT] = new_mlp_weight; + + at_weight_tensors_[IN_MLP_W1_WEIGHT] = torch::zeros({1}).to(device_); +} + +void LlamaDecoderLoader::load_state_dict(const StateDict& state_dict) { + for (const auto& [name, index] : WEIGHT_MAPPING) { + if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { + set_weight(state_dict, name, index, WEIGHT_SHARD[index]); + } else { + set_weight(state_dict, name, index); + } + } +} + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/llama_decoder_loader.h b/xllm/core/layers/npu/loader/llama_decoder_loader.h new file mode 100644 index 000000000..57f086128 --- /dev/null +++ b/xllm/core/layers/npu/loader/llama_decoder_loader.h @@ -0,0 +1,39 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include "base_loader.h" + +namespace xllm { +namespace layer { + +class LlamaDecoderLoader : public BaseLoader { + public: + LlamaDecoderLoader(uint64_t weight_count, const ModelContext& context); + + void load_state_dict(const StateDict& state_dict) override; + void verify_loaded_weights() const override; + void merge_loaded_weights() override; + + bool enableAddNorm_; + int rank_id_; +}; + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/lm_head_loader.cpp b/xllm/core/layers/npu/loader/lm_head_loader.cpp new file mode 100644 index 000000000..99971f91f --- /dev/null +++ b/xllm/core/layers/npu/loader/lm_head_loader.cpp @@ -0,0 +1,42 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "lm_head_loader.h" + +namespace xllm { +namespace layer { + +LmHeadLoader::LmHeadLoader(uint64_t weight_count, const ModelContext& context) + : BaseLoader(weight_count, context) { + auto options = context.get_tensor_options(); + at_weight_tensors_[0] = torch::zeros({1}).to(options); +} + +void LmHeadLoader::load_state_dict(const StateDict& state_dict) { + if (dp_size_ > 1) { + set_weight( + state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_); + } else { + set_weight(state_dict, "weight", 0, 0); + } +} + +void LmHeadLoader::verify_loaded_weights(const std::string& weight_str) const { + CHECK(at_weight_tensors_[0].sizes() != std::vector({1})) + << "final lm_head weight is not loaded for " << weight_str; +} + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/lm_head_loader.h b/xllm/core/layers/npu/loader/lm_head_loader.h new file mode 100644 index 000000000..75d5a2187 --- /dev/null +++ b/xllm/core/layers/npu/loader/lm_head_loader.h @@ -0,0 +1,28 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "base_loader.h" + +namespace xllm { +namespace layer { +class LmHeadLoader : public BaseLoader { + public: + LmHeadLoader(uint64_t weight_count, const ModelContext& context); + + void load_state_dict(const StateDict& state_dict) override; + void verify_loaded_weights(const std::string& weight_str) const override; +}; +} // namespace layer +} // namespace xllm diff --git a/xllm/core/layers/npu/loader/qwen2_decoder_loader.cpp b/xllm/core/layers/npu/loader/qwen2_decoder_loader.cpp new file mode 100644 index 000000000..55412ce86 --- /dev/null +++ b/xllm/core/layers/npu/loader/qwen2_decoder_loader.cpp @@ -0,0 +1,247 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "qwen2_decoder_loader.h" + +namespace xllm { +namespace layer { + +static std::vector> WEIGHT_MAPPING = { + {IN_NORM_WEIGHT, "input_layernorm.weight"}, + {IN_Q_WEIGHT, "self_attn.q_proj.weight"}, + {IN_Q_BIAS, "self_attn.q_proj.bias"}, + {IN_K_WEIGHT, "self_attn.k_proj.weight"}, + {IN_K_BIAS, "self_attn.k_proj.bias"}, + {IN_V_WEIGHT, "self_attn.v_proj.weight"}, + {IN_V_BIAS, "self_attn.v_proj.bias"}, + {IN_ATTENTION_OUT_WEIGHT, "self_attn.o_proj.weight"}, + {IN_SELFOUT_NORM_WEIGHT, "post_attention_layernorm.weight"}, + {IN_MLP_W2_WEIGHT, "mlp.gate_proj.weight"}, + {IN_MLP_W1_WEIGHT, "mlp.up_proj.weight"}, + {IN_MLP_CPROJ_WEIGHT, "mlp.down_proj.weight"}}; + +static std::vector> WEIGHT_MAPPING_W8A8 = { + {IN_NORM_WEIGHT, "input_layernorm.weight"}, + {IN_Q_WEIGHT, "self_attn.q_proj.weight"}, + {IN_Q_BIAS, "self_attn.q_proj.quant_bias"}, + {IN_Q_DEQSCALE, "self_attn.q_proj.deq_scale"}, + {IN_Q_OFFSET, "self_attn.q_proj.input_offset"}, + {IN_Q_SCALE, "self_attn.q_proj.input_scale"}, + {IN_K_WEIGHT, "self_attn.k_proj.weight"}, + {IN_K_BIAS, "self_attn.k_proj.quant_bias"}, + {IN_K_DEQSCALE, "self_attn.k_proj.deq_scale"}, + {IN_K_OFFSET, "self_attn.k_proj.input_offset"}, + {IN_K_SCALE, "self_attn.k_proj.input_scale"}, + {IN_V_WEIGHT, "self_attn.v_proj.weight"}, + {IN_V_BIAS, "self_attn.v_proj.quant_bias"}, + {IN_V_DEQSCALE, "self_attn.v_proj.deq_scale"}, + {IN_V_OFFSET, "self_attn.v_proj.input_offset"}, + {IN_V_SCALE, "self_attn.v_proj.input_scale"}, + {IN_ATTENTION_OUT_WEIGHT, "self_attn.o_proj.weight"}, + {IN_ATTENTION_OUT_BIAS, "self_attn.o_proj.quant_bias"}, + {IN_ATTENTION_OUT_DEQSCALE, "self_attn.o_proj.deq_scale"}, + {IN_ATTENTION_OUT_OFFSET, "self_attn.o_proj.input_offset"}, + {IN_ATTENTION_OUT_SCALE, "self_attn.o_proj.input_scale"}, + {IN_SELFOUT_NORM_WEIGHT, "post_attention_layernorm.weight"}, + {IN_MLP_W2_WEIGHT, "mlp.gate_proj.weight"}, + {IN_MLP_W2_BIAS, "mlp.gate_proj.quant_bias"}, + {IN_MLP_W2_DEQSCALE, "mlp.gate_proj.deq_scale"}, + {IN_MLP_W2_OFFSET, "mlp.gate_proj.input_offset"}, + {IN_MLP_W2_SCALE, "mlp.gate_proj.input_scale"}, + {IN_MLP_W1_WEIGHT, "mlp.up_proj.weight"}, + {IN_MLP_W1_BIAS, "mlp.up_proj.quant_bias"}, + {IN_MLP_W1_DEQSCALE, "mlp.up_proj.deq_scale"}, + {IN_MLP_W1_OFFSET, "mlp.up_proj.input_offset"}, + {IN_MLP_W1_SCALE, "mlp.up_proj.input_scale"}, + {IN_MLP_CPROJ_WEIGHT, "mlp.down_proj.weight"}}; + +static std::map WEIGHT_SHARD = {{IN_Q_WEIGHT, 0}, + {IN_Q_BIAS, 0}, + {IN_K_WEIGHT, 0}, + {IN_K_BIAS, 0}, + {IN_V_WEIGHT, 0}, + {IN_V_BIAS, 0}, + {IN_ATTENTION_OUT_WEIGHT, 1}, + {IN_MLP_W2_WEIGHT, 0}, + {IN_MLP_W1_WEIGHT, 0}, + {IN_MLP_CPROJ_WEIGHT, 1}}; + +static std::map WEIGHT_SHARD_W8A8 = {{IN_Q_WEIGHT, 0}, + {IN_Q_BIAS, 0}, + {IN_Q_DEQSCALE, 0}, + {IN_K_WEIGHT, 0}, + {IN_K_BIAS, 0}, + {IN_K_DEQSCALE, 0}, + {IN_V_WEIGHT, 0}, + {IN_V_BIAS, 0}, + {IN_V_DEQSCALE, 0}, + {IN_ATTENTION_OUT_WEIGHT, 1}, + {IN_MLP_W2_WEIGHT, 0}, + {IN_MLP_W2_BIAS, 0}, + {IN_MLP_W2_DEQSCALE, 0}, + {IN_MLP_W1_WEIGHT, 0}, + {IN_MLP_W1_BIAS, 0}, + {IN_MLP_W1_DEQSCALE, 0}, + {IN_MLP_CPROJ_WEIGHT, 1}}; + +Qwen2DecoderLoader::Qwen2DecoderLoader(uint64_t weight_count, + const ModelContext& context) + : BaseLoader(weight_count, context) { + auto options = context.get_tensor_options(); + device_id_ = options.device().index(); + + for (int i = 0; i < weight_count; ++i) { + at_weight_tensors_[i] = torch::zeros({1}).to(options); + } +} + +void Qwen2DecoderLoader::load_state_dict(const StateDict& state_dict) { + if 
(quantize_type_ == "w8a8") { + for (const auto& [index, name] : WEIGHT_MAPPING_W8A8) { + if (WEIGHT_SHARD_W8A8.find(index) != WEIGHT_SHARD_W8A8.end()) { + set_weight(state_dict, name, index, WEIGHT_SHARD_W8A8[index]); + } else { + set_weight(state_dict, name, index); + } + } + at_weight_tensors_[IN_NORM_BIAS] = + torch::zeros(at_weight_tensors_[IN_NORM_WEIGHT].sizes(), + at_weight_tensors_[IN_NORM_WEIGHT].options()) + .to(device_); + + at_weight_tensors_[IN_SELFOUT_NORM_BIAS] = + torch::zeros(at_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].sizes(), + at_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].options()) + .to(device_); + + return; + } + + for (const auto& [index, name] : WEIGHT_MAPPING) { + if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { + set_weight(state_dict, name, index, WEIGHT_SHARD[index]); + } else { + set_weight(state_dict, name, index); + } + } +} + +void Qwen2DecoderLoader::merge_loaded_weights() { + if (quantize_type_ == "w8a8") { + at_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE] = + at_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE].to(torch::kFloat32); + at_weight_tensors_[IN_Q_DEQSCALE] = + torch::cat({at_weight_tensors_[IN_Q_DEQSCALE], + at_weight_tensors_[IN_K_DEQSCALE], + at_weight_tensors_[IN_V_DEQSCALE]}, + 0) + .to(torch::kFloat32); + at_weight_tensors_[IN_K_DEQSCALE] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_V_DEQSCALE] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_K_OFFSET] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_V_OFFSET] = torch::zeros({1}).to(device_); + + at_weight_tensors_[IN_K_SCALE] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_V_SCALE] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_MLP_W2_BIAS] = + torch::cat({at_weight_tensors_[IN_MLP_W2_BIAS], + at_weight_tensors_[IN_MLP_W1_BIAS]}, + 0); + at_weight_tensors_[IN_MLP_W1_BIAS] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_MLP_W2_DEQSCALE] = + torch::cat({at_weight_tensors_[IN_MLP_W2_DEQSCALE], + at_weight_tensors_[IN_MLP_W1_DEQSCALE]}, + 0) + .to(torch::kFloat32); + at_weight_tensors_[IN_MLP_W1_DEQSCALE] = torch::zeros({1}).to(device_); + + at_weight_tensors_[IN_MLP_W1_OFFSET] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_MLP_W1_SCALE] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_Q_OFFSET] = + at_weight_tensors_[IN_Q_OFFSET].to(torch::kInt8).to(device_); + at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] = + at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] + .to(torch::kInt8) + .to(device_); + at_weight_tensors_[IN_MLP_W2_OFFSET] = + at_weight_tensors_[IN_MLP_W2_OFFSET].to(torch::kInt8).to(device_); + if (device_id_ != 0) { + torch::Tensor original_tensor = at_weight_tensors_[IN_ATTENTION_OUT_BIAS]; + auto shape = original_tensor.sizes(); + auto dtype = original_tensor.dtype(); + auto device = original_tensor.device(); + + at_weight_tensors_[IN_ATTENTION_OUT_BIAS] = torch::zeros( + shape, torch::TensorOptions().dtype(dtype).device(device)); + } + } + + auto new_q_weight = torch::cat({at_weight_tensors_[IN_Q_WEIGHT], + at_weight_tensors_[IN_K_WEIGHT], + at_weight_tensors_[IN_V_WEIGHT]}, + 0); + + at_weight_tensors_[IN_Q_WEIGHT] = new_q_weight; + + at_weight_tensors_[IN_K_WEIGHT] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_V_WEIGHT] = torch::zeros({1}).to(device_); + + auto new_q_bias = torch::cat({at_weight_tensors_[IN_Q_BIAS], + at_weight_tensors_[IN_K_BIAS], + at_weight_tensors_[IN_V_BIAS]}, + 0); + at_weight_tensors_[IN_Q_BIAS] = new_q_bias; + + at_weight_tensors_[IN_K_BIAS] = torch::zeros({1}).to(device_); + 
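/* After the Q/K/V biases are packed into the Q slot above, the K and V slots are reduced to 1-element placeholders, presumably so the downstream layer still receives the fixed weight-tensor layout. */ +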
at_weight_tensors_[IN_V_BIAS] = torch::zeros({1}).to(device_); + + TransposeType transpose_type = + check_transpose(at_weight_tensors_[IN_MLP_W2_WEIGHT]); + if (transpose_type == TransposeType::TRANSPOSE) { + auto new_mlp_weight = torch::cat({at_weight_tensors_[IN_MLP_W2_WEIGHT], + at_weight_tensors_[IN_MLP_W1_WEIGHT]}, + 0); + at_weight_tensors_[IN_MLP_W2_WEIGHT] = new_mlp_weight.contiguous(); + } else { + auto new_mlp_weight = torch::cat({at_weight_tensors_[IN_MLP_W2_WEIGHT], + at_weight_tensors_[IN_MLP_W1_WEIGHT]}, + 0) + .transpose(0, 1); + at_weight_tensors_[IN_MLP_W2_WEIGHT] = new_mlp_weight.contiguous(); + } + + at_weight_tensors_[IN_MLP_W1_WEIGHT] = torch::zeros({1}).to(device_); +} + +TransposeType Qwen2DecoderLoader::check_transpose(torch::Tensor& tensor) { + bool is_k_divisible = tensor.size(1) % 256 == 0; + bool is_n_divisible = tensor.size(0) % 256 == 0; + + if (!is_k_divisible && is_n_divisible) { + return TransposeType::NOT_TRANSPOSE; + } + + return TransposeType::TRANSPOSE; +} + +void Qwen2DecoderLoader::verify_loaded_weights() const { + for (const auto& [index, name] : WEIGHT_MAPPING) { + CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) + << "weight is not loaded for " << name; + } +} + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/qwen2_decoder_loader.h b/xllm/core/layers/npu/loader/qwen2_decoder_loader.h new file mode 100644 index 000000000..1cc04b45e --- /dev/null +++ b/xllm/core/layers/npu/loader/qwen2_decoder_loader.h @@ -0,0 +1,101 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include + +#include "core/layers/npu/npu_base_layer.h" + +namespace xllm { +namespace layer { + +enum DecoderLayerTensorId : int { + IN_NORM_WEIGHT = 0, // weight + IN_NORM_BIAS = 1, // bias + IN_NORM_NEW_WEIGHT = 2, // new weight + IN_NORM_NEW_BIAS = 3, // new bias + + IN_Q_WEIGHT = 4, // weight + IN_Q_BIAS = 5, // bias + IN_Q_DEQSCALE = 6, // deq_scale + IN_Q_OFFSET = 7, // offset + IN_Q_SCALE = 8, // scale + IN_Q_COMPRESS_IDX = 9, + + IN_K_WEIGHT = 10, // weight + IN_K_BIAS = 11, // bias + IN_K_DEQSCALE = 12, // deq_scale + IN_K_OFFSET = 13, // offset + IN_K_SCALE = 14, // scale + IN_K_COMPRESS_IDX = 15, + + IN_V_WEIGHT = 16, // weight + IN_V_BIAS = 17, // bias + IN_V_DEQSCALE = 18, // deq_scale + IN_V_OFFSET = 19, // offset + IN_V_SCALE = 20, // scale + IN_V_COMPRESS_IDX = 21, + + IN_ATTENTION_OUT_WEIGHT = 22, // weight + IN_ATTENTION_OUT_BIAS = 23, // bias + IN_ATTENTION_OUT_DEQSCALE = 24, // deq_scale + IN_ATTENTION_OUT_OFFSET = 25, // offset + IN_ATTENTION_OUT_SCALE = 26, // scale + IN_ATTENTION_OUT_COMPRESS_IDX = 27, + + IN_SELFOUT_NORM_WEIGHT = 28, // weight + IN_SELFOUT_NORM_BIAS = 29, // bias + IN_SELFOUT_NORM_NEW_WEIGHT = 30, // new weight + IN_SELFOUT_NORM_NEW_BIAS = 31, // new bias + + IN_MLP_W2_WEIGHT = 32, // weight + IN_MLP_W2_BIAS = 33, // bias + IN_MLP_W2_DEQSCALE = 34, // deq_scale + IN_MLP_W2_OFFSET = 35, // offset + IN_MLP_W2_SCALE = 36, // scale + IN_MLP_W2_COMPRESS_IDX = 37, + + IN_MLP_W1_WEIGHT = 38, // weight + IN_MLP_W1_BIAS = 39, // bias + IN_MLP_W1_DEQSCALE = 40, // deq_scale + IN_MLP_W1_OFFSET = 41, // offset + IN_MLP_W1_SCALE = 42, // scale + IN_MLP_W1_COMPRESS_IDX = 43, + + IN_MLP_CPROJ_WEIGHT = 44, // weight + IN_MLP_CPROJ_BIAS = 45, // bias + IN_MLP_CPROJ_DEQSCALE = 46, // deq_scale + IN_MLP_CPROJ_OFFSET = 47, // offset + IN_MLP_CPROJ_SCALE = 48, // scale + IN_MLP_CPROJ_COMPRESS_IDX = 49, +}; + +class Qwen2DecoderLoader : public BaseLoader { + public: + Qwen2DecoderLoader(uint64_t weight_count, const ModelContext& context); + + void load_state_dict(const StateDict& state_dict) override; + void verify_loaded_weights() const override; + void merge_loaded_weights() override; + + protected: + int device_id_; + TransposeType check_transpose(torch::Tensor& tensor); +}; + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/qwen2dot5_vision_encoder_loader.cpp b/xllm/core/layers/npu/loader/qwen2dot5_vision_encoder_loader.cpp new file mode 100644 index 000000000..322d239c9 --- /dev/null +++ b/xllm/core/layers/npu/loader/qwen2dot5_vision_encoder_loader.cpp @@ -0,0 +1,303 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifdef TORCH_HIGHER_THAN_PTA6 +#include +#include +#else +#include +#include +#endif + +#include + +#include "qwen2dot5_vision_encoder_loader.h" +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" +#include "torch_npu/csrc/core/npu/NPUException.h" + +namespace xllm { +namespace layer { + +enum VisionEncoderLayerTensorId : int { + IN_INPUT_NORM_WEIGHT = 0, + IN_POST_NORM_WEIGHT, + IN_QKV_WEIGHT, + IN_QKV_BIAS, + IN_WATTENTION_OUT_WEIGHT, + IN_WATTENTION_OUT_BIAS, + IN_MLP_GATE_WEIGHT, + IN_MLP_GATE_BIAS, + IN_MLP_UP_WEIGHT, + IN_MLP_UP_BIAS, + IN_MLP_DOWN_WEIGHT, + IN_MLP_DOWN_BIAS, + IN_VISION_Q_WEIGHT, + IN_VISION_Q_BIAS, + IN_VISION_K_WEIGHT, + IN_VISION_K_BIAS, + IN_VISION_V_WEIGHT, + IN_VISION_V_BIAS +}; + +static std::vector<std::pair<int, std::string>> WEIGHT_MAPPING = { + {IN_INPUT_NORM_WEIGHT, "norm1.weight"}, + {IN_POST_NORM_WEIGHT, "norm2.weight"}, + {IN_QKV_WEIGHT, "qkv.weight"}, + {IN_QKV_BIAS, "qkv.bias"}, + {IN_WATTENTION_OUT_WEIGHT, "attn.proj.weight"}, + {IN_WATTENTION_OUT_BIAS, "attn.proj.bias"}, + {IN_MLP_GATE_WEIGHT, "mlp.gate_proj.weight"}, + {IN_MLP_GATE_BIAS, "mlp.gate_proj.bias"}, + {IN_MLP_UP_WEIGHT, "mlp.up_proj.weight"}, + {IN_MLP_UP_BIAS, "mlp.up_proj.bias"}, + {IN_MLP_DOWN_WEIGHT, "mlp.down_proj.weight"}, + {IN_MLP_DOWN_BIAS, "mlp.down_proj.bias"}, +}; + +// {weight,dim} +static std::map<int, int> WEIGHT_SHARD = { + {IN_WATTENTION_OUT_WEIGHT, 1}, + {IN_MLP_GATE_WEIGHT, 0}, + {IN_MLP_GATE_BIAS, 0}, + {IN_MLP_UP_WEIGHT, 0}, + {IN_MLP_UP_BIAS, 0}, + {IN_MLP_DOWN_WEIGHT, 1}, +}; + +Qwen2dot5VisionEncoderLoader::Qwen2dot5VisionEncoderLoader( + uint64_t weight_count, + const ModelContext& context, + int64_t numAttentionHeadsPerRank) + : BaseLoader(weight_count, context) { + auto model_args = context.get_model_args(); + auto parallel_args = context.get_parallel_args(); + auto options = context.get_tensor_options(); + encode_param_rank = parallel_args.rank(); + encode_param_worldSize = parallel_args.world_size(); + encode_param_numAttentionHeadsPerRank = numAttentionHeadsPerRank; + at_weight_tensors_.resize(weight_count); + dtype_ = torch::typeMetaToScalarType(options.dtype()); + device_id_ = options.device().index(); + at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_); + for (int i = 0; i < weight_count; ++i) { + at_weight_tensors_[i] = torch::zeros({1}).to(options); + } +} + +void Qwen2dot5VisionEncoderLoader::load_state_dict( + const StateDict& state_dict) { + for (const auto& [index, name] : WEIGHT_MAPPING) { + if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { + set_weight(state_dict, name, index, WEIGHT_SHARD[index]); + } else { + set_weight(state_dict, name, index); + } + } + get_weights_col_packed_qkv(); +} + +// tp split weight +void Qwen2dot5VisionEncoderLoader::get_weights_col_packed_qkv() { + int rank = encode_param_rank; + int worldSize = encode_param_worldSize; + // split qkv weight + qkv_weight = torch::chunk(at_weight_tensors_[IN_QKV_WEIGHT], 3, 0); + qkv_bias = torch::chunk(at_weight_tensors_[IN_QKV_BIAS], 3, 0); + // weight + at_weight_tensors_[IN_VISION_Q_WEIGHT] = + (qkv_weight[0].chunk(worldSize, 0))[rank]; + at_weight_tensors_[IN_VISION_K_WEIGHT] = + (qkv_weight[1].chunk(worldSize, 0))[rank]; + at_weight_tensors_[IN_VISION_V_WEIGHT] = + (qkv_weight[2].chunk(worldSize, 0))[rank]; + // bias + at_weight_tensors_[IN_VISION_Q_BIAS] = + (qkv_bias[0].chunk(worldSize, 0))[rank]; + at_weight_tensors_[IN_VISION_K_BIAS] = + (qkv_bias[1].chunk(worldSize, 0))[rank]; +
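/* The bias chunks are sliced per tensor-parallel rank in the same way as the Q/K/V weight shards above. */ +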
at_weight_tensors_[IN_VISION_V_BIAS] = + (qkv_bias[2].chunk(worldSize, 0))[rank]; +} + +void Qwen2dot5VisionEncoderLoader::verify_loaded_weights() const { + for (const auto& [index, name] : WEIGHT_MAPPING) { + CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) + << "weight is not loaded for " << name; + } +} + +void Qwen2dot5VisionEncoderLoader::merge_loaded_weights() { + // spilt pack qkv weight when enable tp + get_weights_col_packed_qkv(); + if (encode_param_worldSize > 1) { + // merge qkv weight + auto new_qkv_weight = torch::cat({at_weight_tensors_[IN_VISION_Q_WEIGHT], + at_weight_tensors_[IN_VISION_K_WEIGHT], + at_weight_tensors_[IN_VISION_V_WEIGHT]}, + 0); + at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight; + at_weight_tensors_[IN_VISION_Q_WEIGHT] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_VISION_K_WEIGHT] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_VISION_V_WEIGHT] = torch::zeros({1}).to(device_); + + // merge qkv bias + auto new_qkv_bias = torch::cat({at_weight_tensors_[IN_VISION_Q_BIAS], + at_weight_tensors_[IN_VISION_K_BIAS], + at_weight_tensors_[IN_VISION_V_BIAS]}, + 0); + at_weight_tensors_[IN_QKV_BIAS] = new_qkv_bias; + at_weight_tensors_[IN_VISION_Q_BIAS] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_VISION_K_BIAS] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_VISION_V_BIAS] = torch::zeros({1}).to(device_); + } + // pad qkv weights + pad_qkv_weights(); + // merge gate up + auto new_mlp_weight = torch::cat({at_weight_tensors_[IN_MLP_GATE_WEIGHT], + at_weight_tensors_[IN_MLP_UP_WEIGHT]}, + 0); + at_weight_tensors_[IN_MLP_GATE_WEIGHT] = new_mlp_weight; + auto new_mlp_bias = torch::cat({at_weight_tensors_[IN_MLP_GATE_BIAS], + at_weight_tensors_[IN_MLP_UP_BIAS]}, + 0); + at_weight_tensors_[IN_MLP_GATE_BIAS] = new_mlp_bias; + at_weight_tensors_[IN_MLP_UP_BIAS] = torch::zeros({1}).to(device_); + // pad mlp weights + pad_mlp_weights(); +} + +void Qwen2dot5VisionEncoderLoader::pad_qkv_weights() { + auto qkv_proj_weight = at_weight_tensors_[IN_QKV_WEIGHT]; + auto qkv_proj_bias = at_weight_tensors_[IN_QKV_BIAS]; + int num_heads_pre_rank = encode_param_numAttentionHeadsPerRank; + int hidden_size = num_heads_pre_rank * 80 * encode_param_worldSize; + + auto qkv_proj_weight_reshaped = + qkv_proj_weight.reshape({num_heads_pre_rank, 3, 80, hidden_size}); + + auto first_half = qkv_proj_weight_reshaped.slice(2, 0, 40); + auto second_half = qkv_proj_weight_reshaped.slice(2, 40, 80); + + auto first_half_padded = torch::nn::functional::pad( + first_half, torch::nn::functional::PadFuncOptions({0, 0, 0, 24})); + auto second_half_padded = torch::nn::functional::pad( + second_half, torch::nn::functional::PadFuncOptions({0, 0, 0, 24})); + + auto qkv_proj_weight_padded = + torch::cat({first_half_padded, second_half_padded}, 2); + auto qkv_proj_weight_final = qkv_proj_weight_padded.reshape( + {num_heads_pre_rank * 128 * 3, hidden_size}); + qkv_proj_weight_final = + at_npu::native::npu_format_cast(qkv_proj_weight_final, 2); + + auto qkv_proj_bias_reshaped = + qkv_proj_bias.reshape({num_heads_pre_rank, 3, 80}); + first_half = qkv_proj_bias_reshaped.slice(2, 0, 40); + second_half = qkv_proj_bias_reshaped.slice(2, 40, 80); + + first_half_padded = torch::nn::functional::pad( + first_half, torch::nn::functional::PadFuncOptions({0, 24})); + second_half_padded = torch::nn::functional::pad( + second_half, torch::nn::functional::PadFuncOptions({0, 24})); + auto qkv_proj_bias_padded = + torch::cat({first_half_padded, second_half_padded}, 2); + auto 
qkv_proj_bias_final = + qkv_proj_bias_padded.reshape({num_heads_pre_rank * 128 * 3}); + + at_weight_tensors_[IN_QKV_WEIGHT] = qkv_proj_weight_final; + at_weight_tensors_[IN_QKV_BIAS] = qkv_proj_bias_final; + + auto out_proj_weight = at_weight_tensors_[IN_WATTENTION_OUT_WEIGHT]; + + out_proj_weight = + torch::nn::functional::pad( + out_proj_weight.reshape({hidden_size, num_heads_pre_rank * 2, 40}), + torch::nn::functional::PadFuncOptions({0, 24, 0, 0})) + .reshape({hidden_size, num_heads_pre_rank * 128}); + at_weight_tensors_[IN_WATTENTION_OUT_WEIGHT] = out_proj_weight; +} + +void Qwen2dot5VisionEncoderLoader::pad_mlp_weights() { + torch::Tensor weight = at_weight_tensors_[IN_MLP_GATE_WEIGHT]; + torch::Tensor bias = at_weight_tensors_[IN_MLP_GATE_BIAS]; + + int64_t tp_intermediate_size_half = weight.size(0) / 2; + int64_t remainder = tp_intermediate_size_half % 32; + int64_t tp_intermediate_size_half_pad; + if (remainder != 0) { + tp_intermediate_size_half_pad = + tp_intermediate_size_half + (32 - remainder); + } else { + tp_intermediate_size_half_pad = tp_intermediate_size_half; + } + auto weight_split1 = weight.slice(0, 0, tp_intermediate_size_half); + auto weight_split2 = weight.slice(0, tp_intermediate_size_half); + auto bias_split1 = bias.slice(0, 0, tp_intermediate_size_half); + auto bias_split2 = bias.slice(0, tp_intermediate_size_half); + + auto weight_split1_padded = + pad_tensor(weight_split1, tp_intermediate_size_half_pad); + auto weight_split2_padded = + pad_tensor(weight_split2, tp_intermediate_size_half_pad); + auto bias_split1_padded = + pad_tensor(bias_split1, tp_intermediate_size_half_pad); + auto bias_split2_padded = + pad_tensor(bias_split2, tp_intermediate_size_half_pad); + + auto weight_padded = + torch::cat({weight_split1_padded, weight_split2_padded}, 0); + auto bias_padded = torch::cat({bias_split1_padded, bias_split2_padded}, 0); + at_weight_tensors_[IN_MLP_GATE_WEIGHT] = weight_padded; + at_weight_tensors_[IN_MLP_GATE_BIAS] = bias_padded; + + torch::Tensor down_weight = at_weight_tensors_[IN_MLP_DOWN_WEIGHT]; + + auto tp_intermediate_size = down_weight.size(1); + remainder = tp_intermediate_size % 32; + int64_t tp_intermediate_size_pad; + if (remainder != 0) { + tp_intermediate_size_pad = tp_intermediate_size + (32 - remainder); + } else { + tp_intermediate_size_pad = tp_intermediate_size; + } + + auto down_weight_padded = + pad_tensor(down_weight, tp_intermediate_size_pad, 1); + at_weight_tensors_[IN_MLP_DOWN_WEIGHT] = down_weight_padded; +} + +torch::Tensor Qwen2dot5VisionEncoderLoader::pad_tensor( + const torch::Tensor& tensor, + int64_t target_shape, + int64_t dim) { + int64_t pad_size = target_shape - tensor.size(dim); + if (tensor.dim() == 1) { + return torch::nn::functional::pad( + tensor, torch::nn::functional::PadFuncOptions({0, pad_size})); + } else if (tensor.dim() == 2) { + if (1 == dim) + return torch::nn::functional::pad( + tensor, torch::nn::functional::PadFuncOptions({0, pad_size, 0, 0})); + else + return torch::nn::functional::pad( + tensor, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); + } + return tensor; +} + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/qwen2dot5_vision_encoder_loader.h b/xllm/core/layers/npu/loader/qwen2dot5_vision_encoder_loader.h new file mode 100644 index 000000000..31dabdb67 --- /dev/null +++ b/xllm/core/layers/npu/loader/qwen2dot5_vision_encoder_loader.h @@ -0,0 +1,57 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include "base_loader.h" + +namespace xllm { +namespace layer { + +class Qwen2dot5VisionEncoderLoader : public BaseLoader { + public: + Qwen2dot5VisionEncoderLoader(uint64_t weight_count, + const ModelContext& context, + int64_t numAttentionHeadsPerRank); + + void load_state_dict(const StateDict& state_dict) override; + void verify_loaded_weights() const override; + void merge_loaded_weights() override; + + private: + void get_weights_col_packed_qkv(); + void pad_qkv_weights(); + void pad_mlp_weights(); + torch::Tensor pad_tensor(const torch::Tensor& tensor, + int64_t target_shape, + int64_t dim = 0); + + protected: + std::string model_name_; + torch::Tensor cu_seqlen_; + torch::Tensor at_placeholder_; + std::vector qkv_weight; + std::vector qkv_bias; + int device_id_; + int encode_param_rank; + int encode_param_worldSize; + int64_t encode_param_numAttentionHeadsPerRank; +}; + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/qwen3_decoder_loader.cpp b/xllm/core/layers/npu/loader/qwen3_decoder_loader.cpp new file mode 100644 index 000000000..6889bdd81 --- /dev/null +++ b/xllm/core/layers/npu/loader/qwen3_decoder_loader.cpp @@ -0,0 +1,333 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "qwen3_decoder_loader.h" + +namespace xllm { +namespace layer { + +enum DecoderLayerTensorId : int { + IN_NORM_WEIGHT = 0, // weight + IN_NORM_BIAS = 1, // bias + IN_NORM_NEW_WEIGHT = 2, // new weight + IN_NORM_NEW_BIAS = 3, // new bias + + IN_Q_WEIGHT = 4, // weight + IN_Q_BIAS = 5, // bias + IN_Q_DEQSCALE = 6, // deq_scale + IN_Q_OFFSET = 7, // offset + IN_Q_SCALE = 8, // scale + IN_Q_COMPRESS_IDX = 9, + + IN_K_WEIGHT = 10, // weight + IN_K_BIAS = 11, // bias + IN_K_DEQSCALE = 12, // deq_scale + IN_K_OFFSET = 13, // offset + IN_K_SCALE = 14, // scale + IN_K_COMPRESS_IDX = 15, + + IN_V_WEIGHT = 16, // weight + IN_V_BIAS = 17, // bias + IN_V_DEQSCALE = 18, // deq_scale + IN_V_OFFSET = 19, // offset + IN_V_SCALE = 20, // scale + IN_V_COMPRESS_IDX = 21, + + IN_ATTENTION_OUT_WEIGHT = 22, // weight + IN_ATTENTION_OUT_BIAS = 23, // bias + IN_ATTENTION_OUT_DEQSCALE = 24, // deq_scale + IN_ATTENTION_OUT_OFFSET = 25, // offset + IN_ATTENTION_OUT_SCALE = 26, // scale + IN_ATTENTION_OUT_COMPRESS_IDX = 27, + + IN_SELFOUT_NORM_WEIGHT = 28, // weight + IN_SELFOUT_NORM_BIAS = 29, // bias + IN_SELFOUT_NORM_NEW_WEIGHT = 30, // new weight + IN_SELFOUT_NORM_NEW_BIAS = 31, // new bias + + IN_MLP_W2_WEIGHT = 32, // weight + IN_MLP_W2_BIAS = 33, // bias + IN_MLP_W2_DEQSCALE = 34, // deq_scale + IN_MLP_W2_OFFSET = 35, // offset + IN_MLP_W2_SCALE = 36, // scale + IN_MLP_W2_COMPRESS_IDX = 37, + + IN_MLP_W1_WEIGHT = 38, // weight + IN_MLP_W1_BIAS = 39, // bias + IN_MLP_W1_DEQSCALE = 40, // deq_scale + IN_MLP_W1_OFFSET = 41, // offset + IN_MLP_W1_SCALE = 42, // scale + IN_MLP_W1_COMPRESS_IDX = 43, + + IN_MLP_CPROJ_WEIGHT = 44, // weight + IN_MLP_CPROJ_BIAS = 45, // bias + IN_MLP_CPROJ_DEQSCALE = 46, // deq_scale + IN_MLP_CPROJ_OFFSET = 47, // offset + IN_MLP_CPROJ_SCALE = 48, // scale + IN_MLP_CPROJ_COMPRESS_IDX = 49, + + IN_QKV_SCALE_FILL = 50, + IN_QKV_OFFSET_FILL = 51, + IN_MLP_SCALE_FILL = 52, + IN_MLP_OFFSET_FILL = 53, + Q_NORM_WEIGHT = 54, + K_NORM_WEIGHT = 55, +}; + +static std::vector> WEIGHT_MAPPING = { + {IN_NORM_WEIGHT, "input_layernorm.weight"}, + {IN_Q_WEIGHT, "self_attn.q_proj.weight"}, + {IN_K_WEIGHT, "self_attn.k_proj.weight"}, + {IN_V_WEIGHT, "self_attn.v_proj.weight"}, + {IN_ATTENTION_OUT_WEIGHT, "self_attn.o_proj.weight"}, + {IN_SELFOUT_NORM_WEIGHT, "post_attention_layernorm.weight"}, + {IN_MLP_W2_WEIGHT, "mlp.gate_proj.weight"}, + {IN_MLP_W1_WEIGHT, "mlp.up_proj.weight"}, + {IN_MLP_CPROJ_WEIGHT, "mlp.down_proj.weight"}, + {Q_NORM_WEIGHT, "self_attn.q_norm.weight"}, + {K_NORM_WEIGHT, "self_attn.k_norm.weight"}}; + +static std::vector> WEIGHT_MAPPING_W8A8 = { + {IN_NORM_WEIGHT, "input_layernorm.weight"}, + {IN_Q_WEIGHT, "self_attn.q_proj.weight"}, + {IN_Q_BIAS, "self_attn.q_proj.quant_bias"}, + {IN_Q_DEQSCALE, "self_attn.q_proj.deq_scale"}, + {IN_Q_OFFSET, "self_attn.q_proj.input_offset"}, + {IN_Q_SCALE, "self_attn.q_proj.input_scale"}, + {IN_K_WEIGHT, "self_attn.k_proj.weight"}, + {IN_K_BIAS, "self_attn.k_proj.quant_bias"}, + {IN_K_DEQSCALE, "self_attn.k_proj.deq_scale"}, + {IN_K_OFFSET, "self_attn.k_proj.input_offset"}, + {IN_K_SCALE, "self_attn.k_proj.input_scale"}, + {IN_V_WEIGHT, "self_attn.v_proj.weight"}, + {IN_V_BIAS, "self_attn.v_proj.quant_bias"}, + {IN_V_DEQSCALE, "self_attn.v_proj.deq_scale"}, + {IN_V_OFFSET, "self_attn.v_proj.input_offset"}, + {IN_V_SCALE, "self_attn.v_proj.input_scale"}, + {IN_ATTENTION_OUT_WEIGHT, "self_attn.o_proj.weight"}, + {IN_ATTENTION_OUT_BIAS, 
"self_attn.o_proj.quant_bias"}, + {IN_ATTENTION_OUT_DEQSCALE, "self_attn.o_proj.deq_scale"}, + {IN_ATTENTION_OUT_OFFSET, "self_attn.o_proj.input_offset"}, + {IN_ATTENTION_OUT_SCALE, "self_attn.o_proj.input_scale"}, + {IN_SELFOUT_NORM_WEIGHT, "post_attention_layernorm.weight"}, + {IN_MLP_W2_WEIGHT, "mlp.gate_proj.weight"}, + {IN_MLP_W2_BIAS, "mlp.gate_proj.quant_bias"}, + {IN_MLP_W2_DEQSCALE, "mlp.gate_proj.deq_scale"}, + {IN_MLP_W2_OFFSET, "mlp.gate_proj.input_offset"}, + {IN_MLP_W2_SCALE, "mlp.gate_proj.input_scale"}, + {IN_MLP_W1_WEIGHT, "mlp.up_proj.weight"}, + {IN_MLP_W1_BIAS, "mlp.up_proj.quant_bias"}, + {IN_MLP_W1_DEQSCALE, "mlp.up_proj.deq_scale"}, + {IN_MLP_W1_OFFSET, "mlp.up_proj.input_offset"}, + {IN_MLP_W1_SCALE, "mlp.up_proj.input_scale"}, + {IN_MLP_CPROJ_WEIGHT, "mlp.down_proj.weight"}, + {Q_NORM_WEIGHT, "self_attn.q_norm.weight"}, + {K_NORM_WEIGHT, "self_attn.k_norm.weight"}}; + +static std::map WEIGHT_SHARD = {{IN_Q_WEIGHT, 0}, + {IN_K_WEIGHT, 0}, + {IN_V_WEIGHT, 0}, + {IN_ATTENTION_OUT_WEIGHT, 1}, + {IN_MLP_W2_WEIGHT, 0}, + {IN_MLP_W1_WEIGHT, 0}, + {IN_MLP_CPROJ_WEIGHT, 1}}; + +static std::map WEIGHT_SHARD_W8A8 = {{IN_Q_WEIGHT, 0}, + {IN_Q_BIAS, 0}, + {IN_Q_DEQSCALE, 0}, + {IN_K_WEIGHT, 0}, + {IN_K_BIAS, 0}, + {IN_K_DEQSCALE, 0}, + {IN_V_WEIGHT, 0}, + {IN_V_BIAS, 0}, + {IN_V_DEQSCALE, 0}, + {IN_ATTENTION_OUT_WEIGHT, 1}, + {IN_MLP_W2_WEIGHT, 0}, + {IN_MLP_W2_BIAS, 0}, + {IN_MLP_W2_DEQSCALE, 0}, + {IN_MLP_W1_WEIGHT, 0}, + {IN_MLP_W1_BIAS, 0}, + {IN_MLP_W1_DEQSCALE, 0}, + {IN_MLP_CPROJ_WEIGHT, 1}}; + +Qwen3DecoderLoader::Qwen3DecoderLoader(uint64_t weight_count, + const ModelContext& context, + bool enableAddNorm) + : BaseLoader(weight_count, context), enableAddNorm_(enableAddNorm) { + auto options = context.get_tensor_options(); + rank_id_ = parallel_args_.rank(); + + dtype_ = torch::typeMetaToScalarType(options.dtype()); + for (int i = 0; i < weight_count; ++i) { + at_weight_tensors_[i] = torch::zeros({1}).to(options); + } + + at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_); +} + +void Qwen3DecoderLoader::load_state_dict(const StateDict& state_dict) { + if (quantize_type_.compare("w8a8") == 0) { + for (const auto& [index, name] : WEIGHT_MAPPING_W8A8) { + if (WEIGHT_SHARD_W8A8.find(index) != WEIGHT_SHARD_W8A8.end()) { + set_weight(state_dict, name, index, WEIGHT_SHARD_W8A8[index]); + } else { + set_weight(state_dict, name, index); + } + } + at_weight_tensors_[IN_NORM_BIAS] = + torch::zeros(at_weight_tensors_[IN_NORM_WEIGHT].sizes(), + at_weight_tensors_[IN_NORM_WEIGHT].options()) + .to(device_); + + at_weight_tensors_[IN_SELFOUT_NORM_BIAS] = + torch::zeros(at_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].sizes(), + at_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].options()) + .to(device_); + return; + } + + for (const auto& [index, name] : WEIGHT_MAPPING) { + if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { + set_weight(state_dict, name, index, WEIGHT_SHARD[index]); + } else { + set_weight(state_dict, name, index); + } + } +} + +void Qwen3DecoderLoader::verify_loaded_weights() const { + for (const auto& [index, name] : WEIGHT_MAPPING) { + CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) + << "weight is not loaded for " << name; + } +} + +void Qwen3DecoderLoader::merge_loaded_weights() { + if (quantize_type_.compare("w8a8") == 0) { + at_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE] = + at_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE].to(torch::kFloat32); + at_weight_tensors_[IN_Q_DEQSCALE] = + torch::cat({at_weight_tensors_[IN_Q_DEQSCALE], + 
at_weight_tensors_[IN_K_DEQSCALE], + at_weight_tensors_[IN_V_DEQSCALE]}, + 0) + .to(torch::kFloat32); + + at_weight_tensors_[IN_Q_BIAS] = torch::cat({at_weight_tensors_[IN_Q_BIAS], + at_weight_tensors_[IN_K_BIAS], + at_weight_tensors_[IN_V_BIAS]}, + 0) + .to(torch::kInt32); + + for (auto idx : {IN_K_DEQSCALE, + IN_V_DEQSCALE, + IN_K_BIAS, + IN_V_BIAS, + IN_K_OFFSET, + IN_V_OFFSET, + IN_K_SCALE, + IN_V_SCALE}) { + at_weight_tensors_[idx] = at_placeholder_; + } + + at_weight_tensors_[IN_MLP_W2_BIAS] = + torch::cat({at_weight_tensors_[IN_MLP_W2_BIAS], + at_weight_tensors_[IN_MLP_W1_BIAS]}, + 0); + + at_weight_tensors_[IN_MLP_W2_DEQSCALE] = + torch::cat({at_weight_tensors_[IN_MLP_W2_DEQSCALE], + at_weight_tensors_[IN_MLP_W1_DEQSCALE]}, + 0) + .to(torch::kFloat32); + + for (auto idx : {IN_MLP_W1_BIAS, + IN_MLP_W1_OFFSET, + IN_MLP_W1_SCALE, + IN_MLP_W1_DEQSCALE}) { + at_weight_tensors_[idx] = at_placeholder_; + } + + at_weight_tensors_[IN_Q_OFFSET] = + at_weight_tensors_[IN_Q_OFFSET].to(torch::kInt8).to(device_); + at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] = + at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] + .to(torch::kInt8) + .to(device_); + at_weight_tensors_[IN_MLP_W2_OFFSET] = + at_weight_tensors_[IN_MLP_W2_OFFSET].to(torch::kInt8).to(device_); + + if (rank_id_ != 0) { + torch::Tensor original_tensor = at_weight_tensors_[IN_ATTENTION_OUT_BIAS]; + auto shape = original_tensor.sizes(); + auto dtype = original_tensor.dtype(); + auto device = original_tensor.device(); + + at_weight_tensors_[IN_ATTENTION_OUT_BIAS] = torch::zeros( + shape, torch::TensorOptions().dtype(dtype).device(device)); + } + } + + at_weight_tensors_[IN_Q_WEIGHT] = + torch::cat({at_weight_tensors_[IN_Q_WEIGHT], + at_weight_tensors_[IN_K_WEIGHT], + at_weight_tensors_[IN_V_WEIGHT]}, + 0) + .contiguous(); + + at_weight_tensors_[IN_MLP_W2_WEIGHT] = + torch::cat({at_weight_tensors_[IN_MLP_W2_WEIGHT], + at_weight_tensors_[IN_MLP_W1_WEIGHT]}, + 0) + .contiguous(); + + for (auto idx : + {IN_MLP_W1_WEIGHT, IN_K_WEIGHT, IN_V_WEIGHT, IN_K_BIAS, IN_V_BIAS}) { + at_weight_tensors_[idx] = at_placeholder_; + } + + if (enableAddNorm_) { + if (quantize_type_.compare("w8a8") == 0) { + // quantize + torch::ScalarType weight_fill_dtype = torch::kBFloat16; + int64_t weight_attn_shape = at_weight_tensors_[IN_Q_WEIGHT].size(-1); + int64_t weight_mlp_shape = at_weight_tensors_[IN_MLP_W2_WEIGHT].size(-1); + at_weight_tensors_[IN_QKV_SCALE_FILL] = at_weight_tensors_[IN_Q_SCALE] + .repeat(weight_attn_shape) + .to(weight_fill_dtype); + at_weight_tensors_[IN_MLP_SCALE_FILL] = + at_weight_tensors_[IN_MLP_W2_SCALE] + .repeat(weight_mlp_shape) + .to(weight_fill_dtype); + at_weight_tensors_[IN_QKV_OFFSET_FILL] = at_weight_tensors_[IN_Q_OFFSET] + .repeat(weight_attn_shape) + .to(weight_fill_dtype); + at_weight_tensors_[IN_MLP_OFFSET_FILL] = + at_weight_tensors_[IN_MLP_W2_OFFSET] + .repeat(weight_mlp_shape) + .to(weight_fill_dtype); + } else { + // bfloat16 or float16 + for (auto idx : {IN_QKV_SCALE_FILL, + IN_QKV_OFFSET_FILL, + IN_MLP_SCALE_FILL, + IN_MLP_OFFSET_FILL}) { + at_weight_tensors_[idx] = at_placeholder_; + } + } + } +} +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/qwen3_decoder_loader.h b/xllm/core/layers/npu/loader/qwen3_decoder_loader.h new file mode 100644 index 000000000..412a5c264 --- /dev/null +++ b/xllm/core/layers/npu/loader/qwen3_decoder_loader.h @@ -0,0 +1,43 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include + +#include "base_loader.h" + +namespace xllm { +namespace layer { + +class Qwen3DecoderLoader : public BaseLoader { + public: + Qwen3DecoderLoader(uint64_t weight_count, + const ModelContext& context, + bool enableAddNorm); + + void load_state_dict(const StateDict& state_dict) override; + void verify_loaded_weights() const override; + void merge_loaded_weights() override; + + protected: + torch::Tensor at_placeholder_; + bool enableAddNorm_; + int rank_id_; +}; + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/qwen3_moe_decoder_loader.cpp b/xllm/core/layers/npu/loader/qwen3_moe_decoder_loader.cpp new file mode 100644 index 000000000..74465fe4e --- /dev/null +++ b/xllm/core/layers/npu/loader/qwen3_moe_decoder_loader.cpp @@ -0,0 +1,635 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "qwen3_moe_decoder_loader.h" + +#include +#include +#include + +#include "xllm_kernels/core/include/atb_speed/base/hosttensor_binder.h" +#include "xllm_kernels/core/include/atb_speed/base/model.h" +#include "xllm_kernels/core/include/atb_speed/log.h" +#include "xllm_kernels/core/include/atb_speed/utils/model_factory.h" +#include "xllm_kernels/models/qwen3/layer/moe_decoder_layer.h" + +namespace xllm { +namespace layer { +enum DecoderLayerTensorId : int { + IN_INPUT_NORM_WEIGHT = 0, // [2048] + IN_INPUT_NORM_BIAS = 1, + IN_INPUT_NORM_NEW_WEIGHT = 2, + IN_INPUT_NORM_NEW_BIAS = 3, + + IN_QKV_WEIGHT_0 = 4, // [4096, 2048] + IN_QKV_BIAS_0 = 5, + IN_QKV_DESCALE_0 = 6, + IN_QKV_OFFSET_0 = 7, + IN_QKV_SCALE_0 = 8, + IN_QKV_COMPRESS_IDX_0 = 9, + + IN_QKV_WEIGHT_1 = 10, // [512, 2048] + IN_QKV_BIAS_1 = 11, + IN_QKV_DESCALE_1 = 12, + IN_QKV_OFFSET_1 = 13, + IN_QKV_SCALE_1 = 14, + IN_QKV_COMPRESS_IDX_1 = 15, + + IN_QKV_WEIGHT_2 = 16, // [512, 2048] + IN_QKV_BIAS_2 = 17, + IN_QKV_DESCALE_2 = 18, + IN_QKV_OFFSET_2 = 19, + IN_QKV_SCALE_2 = 20, + IN_QKV_COMPRESS_IDX_2 = 21, + + IN_ATTENTION_OUT_WEIGHT = 22, // [2048, 4096] + IN_ATTENTION_OUT_BIAS = 23, + IN_ATTENTION_OUT_DESCALE = 24, + IN_ATTENTION_OUT_OFFSET = 25, + IN_ATTENTION_OUT_SCALE = 26, + IN_ATTENTION_OUT_COMPRESS_IDX = 27, + + IN_Q_NORM_WEIGHT = 28, // [128] + IN_K_NORM_WEIGHT = 29, // [128] + + IN_SELFATTENTION_OUT_NORM_WEIGHT = 30, // [2048] + IN_SELFATTENTION_OUT_NORM_BIAS = 31, + IN_SELFATTENTION_OUT_NEW_NORM_WEIGHT = 32, + IN_SELFATTENTION_OUT_NEW_NORM_BIAS = 33, + + IN_BLOCK_SPARSE_MOE_GATE_WEIGHT = 34, // [128, 2048] + IN_BLOCK_SPARSE_MOE_GATE_BIAS = 35, + IN_BLOCK_SPARSE_MOE_GATE_DESCALE = 36, + IN_BLOCK_SPARSE_MOE_GATE_OFFSET = 37, + IN_BLOCK_SPARSE_MOE_GATE_SCALE = 38, + IN_BLOCK_SPARSE_MOE_GATE_COMPRESS_IDX = 39, + + IN_MLP_GATEUP_WEIGHT_EXPERT = 40, + IN_MLP_GATEUP_BIAS_EXPERT = 41, + IN_MLP_GATEUP_DESCALE_EXPERT = 42, + IN_MLP_GATEUP_OFFSET_EXPERT = 43, + IN_MLP_GATEUP_SCALE_EXPERT = 44, + IN_MLP_GATEUP_COMPRESS_IDX_EXPERT = 45, + + IN_MLP_DOWN_WEIGHT_EXPERT = 46, // [2048, 768] + IN_MLP_DOWN_BIAS_EXPERT = 47, + IN_MLP_DOWN_DESCALE_EXPERT = 48, + IN_MLP_DOWN_OFFSET_EXPERT = 49, + IN_MLP_DOWN_SCALE_EXPERT = 50, + IN_MLP_DOWN_COMPRESS_IDX_EXPERT = 51, + + IN_MLP_SHARED_GATEUP_WEIGHT = 52, + IN_MLP_SHARED_DOWN_WEIGHT = 53, + IN_MLP_SHARED_EXPERT_GATE = 54, +}; + +static const std::unordered_map WEIGHT_MAPPING = { + {"input_layernorm.weight", IN_INPUT_NORM_WEIGHT}, + + {"self_attn.q_proj.weight", IN_QKV_WEIGHT_0}, + + {"self_attn.k_proj.weight", IN_QKV_WEIGHT_1}, + + {"self_attn.v_proj.weight", IN_QKV_WEIGHT_2}, + + {"self_attn.o_proj.weight", IN_ATTENTION_OUT_WEIGHT}, + + {"self_attn.q_norm.weight", IN_Q_NORM_WEIGHT}, + {"self_attn.k_norm.weight", IN_K_NORM_WEIGHT}, + + {"post_attention_layernorm.weight", IN_SELFATTENTION_OUT_NORM_WEIGHT}, + + // MoE Gate + {"mlp.gate.weight", IN_BLOCK_SPARSE_MOE_GATE_WEIGHT}, + + // Expert MLP - Gate/Up projections + {"gate_proj.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, + + {"up_proj.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, + + // Expert MLP - Down projection + {"down_proj.weight", IN_MLP_DOWN_WEIGHT_EXPERT}, + +}; + +static const std::unordered_map WEIGHT_MAPPING_W8A8 = { + {"input_layernorm.weight", IN_INPUT_NORM_WEIGHT}, + {"input_layernorm.bias", IN_INPUT_NORM_NEW_BIAS}, + + {"self_attn.q_proj.weight", IN_QKV_WEIGHT_0}, + {"self_attn.q_proj.bias", IN_QKV_BIAS_0}, + {"self_attn.q_proj.deq_scale", 
IN_QKV_DESCALE_0}, + {"self_attn.q_proj.weight_offset", IN_QKV_OFFSET_0}, + {"self_attn.q_proj.weight_scale", IN_QKV_SCALE_0}, + + {"self_attn.k_proj.weight", IN_QKV_WEIGHT_1}, + {"self_attn.k_proj.bias", IN_QKV_BIAS_1}, + {"self_attn.k_proj.deq_scale", IN_QKV_DESCALE_1}, + {"self_attn.k_proj.weight_offset", IN_QKV_OFFSET_1}, + {"self_attn.k_proj.weight_scale", IN_QKV_SCALE_1}, + + {"self_attn.v_proj.weight", IN_QKV_WEIGHT_2}, + {"self_attn.v_proj.bias", IN_QKV_BIAS_2}, + {"self_attn.v_proj.deq_scale", IN_QKV_DESCALE_2}, + {"self_attn.v_proj.weight_offset", IN_QKV_OFFSET_2}, + {"self_attn.v_proj.weight_scale", IN_QKV_SCALE_2}, + + {"self_attn.o_proj.weight", IN_ATTENTION_OUT_WEIGHT}, + {"self_attn.o_proj.quant_bias", IN_ATTENTION_OUT_BIAS}, + {"self_attn.o_proj.deq_scale", IN_ATTENTION_OUT_DESCALE}, + {"self_attn.o_proj.weight_offset", IN_ATTENTION_OUT_OFFSET}, + {"self_attn.o_proj.weight_scale", IN_ATTENTION_OUT_SCALE}, + + {"self_attn.q_norm.weight", IN_Q_NORM_WEIGHT}, + {"self_attn.k_norm.weight", IN_K_NORM_WEIGHT}, + + {"post_attention_layernorm.weight", IN_SELFATTENTION_OUT_NORM_WEIGHT}, + {"post_attention_layernorm.bias", IN_SELFATTENTION_OUT_NEW_NORM_BIAS}, + + // MoE Gate + {"mlp.gate.weight", IN_BLOCK_SPARSE_MOE_GATE_WEIGHT}, + + {"gate_proj.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, + {"gate_proj.weight_offset", IN_MLP_GATEUP_OFFSET_EXPERT}, + {"gate_proj.weight_scale", IN_MLP_GATEUP_SCALE_EXPERT}, + {"up_proj.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, + {"up_proj.weight_offset", IN_MLP_GATEUP_OFFSET_EXPERT}, + {"up_proj.weight_scale", IN_MLP_GATEUP_SCALE_EXPERT}, + + {"down_proj.weight", IN_MLP_DOWN_WEIGHT_EXPERT}, + {"down_proj.weight_offset", IN_MLP_DOWN_OFFSET_EXPERT}, + {"down_proj.weight_scale", IN_MLP_DOWN_SCALE_EXPERT}, +}; + +static const std::unordered_map> + SPECIAL_MULTI_ASSIGN_W8A8 = { + {"input_layernorm.weight", + {IN_INPUT_NORM_WEIGHT, IN_INPUT_NORM_NEW_WEIGHT}}, + {"post_attention_layernorm.weight", + {IN_SELFATTENTION_OUT_NORM_WEIGHT, + IN_SELFATTENTION_OUT_NEW_NORM_WEIGHT}}, +}; + +static const std::map WEIGHT_SHARD = { + {IN_QKV_WEIGHT_0, 0}, + {IN_QKV_WEIGHT_1, 0}, + {IN_QKV_WEIGHT_2, 0}, + {IN_ATTENTION_OUT_WEIGHT, 1}, + {IN_MLP_GATEUP_WEIGHT_EXPERT, 0}, + {IN_MLP_DOWN_WEIGHT_EXPERT, 1}, +}; + +static const std::map WEIGHT_SHARD_W8A8 = { + {IN_QKV_WEIGHT_0, 0}, + {IN_QKV_OFFSET_0, 0}, + {IN_QKV_SCALE_0, 0}, + {IN_QKV_WEIGHT_1, 0}, + {IN_QKV_OFFSET_1, 0}, + {IN_QKV_SCALE_1, 0}, + {IN_QKV_WEIGHT_2, 0}, + {IN_QKV_OFFSET_2, 0}, + {IN_QKV_SCALE_2, 0}, + {IN_ATTENTION_OUT_WEIGHT, 1}, + {IN_MLP_GATEUP_WEIGHT_EXPERT, 0}, + {IN_MLP_GATEUP_OFFSET_EXPERT, 0}, + {IN_MLP_GATEUP_SCALE_EXPERT, 0}, + {IN_MLP_DOWN_WEIGHT_EXPERT, 1}, +}; + +Qwen3MoeDecoderLoader::Qwen3MoeDecoderLoader(uint64_t weight_count, + const ModelContext& context) + : BaseLoader(weight_count, context) { + auto model_args = context.get_model_args(); + auto options = context.get_tensor_options(); + + for (int i = 0; i < weight_count; ++i) { + at_weight_tensors_[i] = torch::zeros({1}).to(options); + } + + num_experts_ = model_args.num_experts(); + ep_size_ = parallel_args_.ep_size(); + ep_local_tp_size_ = parallel_args_.world_size() / ep_size_; + CHECK_EQ(parallel_args_.world_size(), ep_size_ * ep_local_tp_size_); + ep_local_tp_rank_ = parallel_args_.rank() % ep_local_tp_size_; + num_experts_per_partition_ = model_args.num_experts() / ep_size_; + ep_rank_ = parallel_args_.rank() / ep_local_tp_size_; + start_expert_id_ = ep_rank_ * num_experts_per_partition_; + end_expert_id_ = start_expert_id_ + 
num_experts_per_partition_ - 1; + n_kv_heads_ = static_cast(model_args.n_kv_heads().value()); + + dp_size_ = parallel_args_.dp_size(); + dp_local_tp_size_ = parallel_args_.world_size() / dp_size_; + CHECK_EQ(parallel_args_.world_size(), dp_size_ * dp_local_tp_size_); + dp_local_tp_rank_ = parallel_args_.rank() % dp_local_tp_size_; +} + +void Qwen3MoeDecoderLoader::load_state_dict(const StateDict& state_dict) { + for (const auto& [name, tensor] : state_dict) { + bool is_sharded = false; + int index = 0; + + if (absl::StartsWith(name, "mlp.experts")) { + process_expert_weights(state_dict, name, tensor); + continue; + } + + if (absl::StartsWith(name, "mlp") && !absl::StrContains(name, "gate.")) { + process_mlp_common_weights(state_dict, name, tensor); + continue; + } + + process_general_weights(state_dict, name, tensor); + } +} + +void Qwen3MoeDecoderLoader::verify_loaded_weights( + const std::string& prefix) const { + for (const auto& [name, index] : WEIGHT_MAPPING) { + if (name == "down_proj.weight" || name == "gate_proj.weight" || + name == "up_proj.weight") { + continue; + } + CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) + << "weight is not loaded for " << name; + } +} + +void Qwen3MoeDecoderLoader::merge_experts_weights() { + if (experts_weights_.count("gate_proj.weight") > 0) { + auto& gate_weight = experts_weights_["gate_proj.weight"]; + } + + if (experts_weights_.count("up_proj.weight") > 0) { + auto& up_weight = experts_weights_["up_proj.weight"]; + } + + try { + torch::Tensor mlp_gateup_weight; + if (quantize_type_.compare("w8a8_dynamic") == 0) { + mlp_gateup_weight = + merge_experts_weights(experts_weights_["gate_proj.weight"], + experts_weights_["up_proj.weight"], + /*transpose=*/true); + at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT] = + merge_experts_weights(experts_weights_["gate_proj.weight_offset"], + experts_weights_["up_proj.weight_offset"]); + at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT] = + merge_experts_weights(experts_weights_["gate_proj.weight_scale"], + experts_weights_["up_proj.weight_scale"]); + } else { + mlp_gateup_weight = + merge_experts_weights(experts_weights_["gate_proj.weight"], + experts_weights_["up_proj.weight"], + /*transpose=*/false); + } + at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT] = + at_npu::native::npu_format_cast(mlp_gateup_weight, 2).contiguous(); + } catch (const std::exception& e) { + LOG(ERROR) << "[ERROR] Exception in gateup weight processing: " << e.what(); + throw; + } + + if (experts_weights_.count("down_proj.weight") > 0) { + auto& down_weight = experts_weights_["down_proj.weight"]; + } + + try { + torch::Tensor mlp_down_weight = + merge_experts_weights(experts_weights_["down_proj.weight"], + /*transpose=*/false); + + at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = + at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); + + if (quantize_type_.compare("w8a8_dynamic") == 0) { + at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = + merge_experts_weights(experts_weights_["down_proj.weight_offset"]); + at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = + merge_experts_weights(experts_weights_["down_proj.weight_scale"]); + } + } catch (const std::exception& e) { + LOG(ERROR) << "[ERROR] Exception in down weight processing: " << e.what(); + throw; + } +} + +torch::Tensor Qwen3MoeDecoderLoader::merge_experts_weights( + std::vector& experts, + bool transpose) { + torch::Tensor merged_tensor = torch::stack(experts, 0).to(device_); + if (transpose) { + merged_tensor = merged_tensor.transpose(1, 2); + } + merged_tensor = 
merged_tensor.contiguous(); + experts.clear(); + + return merged_tensor; +} + +std::string Qwen3MoeDecoderLoader::extract_endswith(const std::string& input) { + std::vector parts; + std::stringstream ss(input); + std::string part; + while (std::getline(ss, part, '.')) { + parts.push_back(part); + } + if (parts.size() < 2) { + return ""; + } + std::string result = parts[parts.size() - 2] + "." + parts[parts.size() - 1]; + + return result; +} + +int Qwen3MoeDecoderLoader::extract_expert_index(const std::string& name) { + std::string prefix = "experts."; + size_t pos = name.find(prefix); + if (pos != std::string::npos) { + pos += prefix.length(); + size_t end_pos = pos; + while (end_pos < name.length() && std::isdigit(name[end_pos])) { + ++end_pos; + } + if (end_pos > pos) { + return std::stoi(name.substr(pos, end_pos - pos)); + } + } + + return -1; +} + +void Qwen3MoeDecoderLoader::resize_experts_weights(int num_of_device_experts) { + experts_weights_["gate_proj.weight"] = + std::vector(num_of_device_experts); + experts_weights_["up_proj.weight"] = + std::vector(num_of_device_experts); + experts_weights_["down_proj.weight"] = + std::vector(num_of_device_experts); + if (quantize_type_.compare("w8a8_dynamic") == 0) { + experts_weights_["gate_proj.weight_offset"] = + std::vector(num_of_device_experts); + experts_weights_["up_proj.weight_offset"] = + std::vector(num_of_device_experts); + experts_weights_["down_proj.weight_offset"] = + std::vector(num_of_device_experts); + experts_weights_["gate_proj.weight_scale"] = + std::vector(num_of_device_experts); + experts_weights_["up_proj.weight_scale"] = + std::vector(num_of_device_experts); + experts_weights_["down_proj.weight_scale"] = + std::vector(num_of_device_experts); + } +} + +void Qwen3MoeDecoderLoader::process_expert_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + int expert_index = extract_expert_index(name); + if (expert_index < start_expert_id_ || expert_index > end_expert_id_) { + return; + } + + const std::string suffix = extract_endswith(name); + const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_MAPPING_W8A8 + : WEIGHT_MAPPING; + const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_SHARD_W8A8 + : WEIGHT_SHARD; + const int index = get_mapped_index(suffix, weight_mapping); + const int local_index = expert_index % num_experts_per_partition_; + const bool is_sharded = shard_map.count(index); + + torch::Tensor tmp_tensor = is_sharded + ? get_sharded_tensor(state_dict, + name, + shard_map.at(index), + ep_local_tp_rank_, + ep_local_tp_size_) + : tensor; + + experts_weights_[suffix][local_index] = tmp_tensor.clone(); +} + +int Qwen3MoeDecoderLoader::get_mapped_index( + const std::string& name, + const std::unordered_map& mapping) { + const auto it = mapping.find(name); + if (it == mapping.end()) { + LOG(ERROR) << "Missing mapping for: " << name; + return -1; + } + + return it->second; +} + +void Qwen3MoeDecoderLoader::process_mlp_common_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_MAPPING_W8A8 + : WEIGHT_MAPPING; + const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_SHARD_W8A8 + : WEIGHT_SHARD; + const int index = get_mapped_index(name, weight_mapping); + const bool is_sharded = shard_map.count(index); + + torch::Tensor tmp_tensor = is_sharded + ? 
get_sharded_tensor(state_dict, + name, + shard_map.at(index), + dp_local_tp_rank_, + dp_local_tp_size_) + .to(device_) + : tensor.to(device_); + if (absl::StrContains(name, "down_proj")) { + at_weight_tensors_[index] = tmp_tensor; + } else { + shared_experts_weights_[name] = tmp_tensor; + } +} + +void Qwen3MoeDecoderLoader::process_general_weights( + const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor) { + const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_MAPPING_W8A8 + : WEIGHT_MAPPING; + const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) + ? WEIGHT_SHARD_W8A8 + : WEIGHT_SHARD; + + if (weight_mapping.find(name) == weight_mapping.end()) { + return; + } + + const int index = get_mapped_index(name, weight_mapping); + const bool is_sharded = shard_map.count(index); + torch::Tensor tmp_tensor; + int32_t tp_rank = dp_local_tp_rank_; + int32_t tp_size = dp_local_tp_size_; + + static const std::unordered_set qkv_tensor_indices = {IN_QKV_WEIGHT_1, + IN_QKV_WEIGHT_2, + IN_QKV_BIAS_1, + IN_QKV_BIAS_2, + IN_QKV_DESCALE_1, + IN_QKV_DESCALE_2, + IN_QKV_OFFSET_1, + IN_QKV_OFFSET_2, + IN_QKV_SCALE_1, + IN_QKV_SCALE_2}; + + if (qkv_tensor_indices.count(index) > 0) { + if (n_kv_heads_ < dp_local_tp_size_) { + int32_t repeat_times = (dp_local_tp_size_ / n_kv_heads_); + + tp_rank = tp_rank / repeat_times; + tp_size = n_kv_heads_; + } + } + if (is_sharded) { + tmp_tensor = get_sharded_tensor( + state_dict, name, shard_map.at(index), tp_rank, tp_size) + .to(device_); + } else { + tmp_tensor = tensor.to(device_); + } + + correct_tensor_dtype(tmp_tensor, name); + if (quantize_type_.compare("w8a8_dynamic") == 0) { + auto it = SPECIAL_MULTI_ASSIGN_W8A8.find(name); + if (it != SPECIAL_MULTI_ASSIGN_W8A8.end()) { + for (int idx : it->second) { + at_weight_tensors_[idx] = tmp_tensor; + } + return; + } + } + at_weight_tensors_[index] = tmp_tensor; +} + +torch::Tensor Qwen3MoeDecoderLoader::get_sharded_tensor( + const StateDict& state_dict, + const std::string& name, + int dim) { + if (parallel_args_.world_size() > 1) { + return state_dict.get_sharded_tensor( + name, dim, parallel_args_.rank(), parallel_args_.world_size()); + } else { + return state_dict.get_tensor(name); + } +} + +torch::Tensor Qwen3MoeDecoderLoader::get_sharded_tensor( + const StateDict& state_dict, + const std::string& name, + int dim, + int loacal_tp_rank, + int local_tp_size) { + if (local_tp_size > 1) { + return state_dict.get_sharded_tensor( + name, dim, loacal_tp_rank, local_tp_size); + } else { + return state_dict.get_tensor(name); + } +} + +torch::Tensor Qwen3MoeDecoderLoader::merge_experts_weights( + std::vector& experts_gate, + std::vector& experts_up, + bool transpose) { + for (size_t i = 0; i < experts_up.size(); ++i) { + experts_gate[i] = torch::cat({experts_gate[i], experts_up[i]}, 0); + } + torch::Tensor merged_tensor = torch::stack(experts_gate, 0).to(device_); + if (transpose) { + merged_tensor = merged_tensor.transpose(1, 2); + } + merged_tensor = merged_tensor.contiguous(); + experts_gate.clear(); + experts_up.clear(); + + return merged_tensor; +} + +void Qwen3MoeDecoderLoader::merge_loaded_weights() { + merge_experts_weights(); + at_weight_tensors_[IN_QKV_WEIGHT_0] = + torch::cat({at_weight_tensors_[IN_QKV_WEIGHT_0], + at_weight_tensors_[IN_QKV_WEIGHT_1], + at_weight_tensors_[IN_QKV_WEIGHT_2]}, + 0) + .contiguous(); + at_weight_tensors_[IN_QKV_WEIGHT_1] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_WEIGHT_2] = 
+ torch::zeros({1}, torch::kFloat16).to(device_); + + if (quantize_type_.compare("w8a8_dynamic") == 0) { + at_weight_tensors_[IN_QKV_BIAS_0] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_BIAS_1] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_BIAS_2] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_ATTENTION_OUT_BIAS] = + torch::zeros({1}, torch::kFloat16).to(device_); + + at_weight_tensors_[IN_QKV_DESCALE_0] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_DESCALE_1] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_DESCALE_2] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_ATTENTION_OUT_DESCALE] = + torch::zeros({1}, torch::kFloat16).to(device_); + + at_weight_tensors_[IN_QKV_OFFSET_0] = + torch::cat({at_weight_tensors_[IN_QKV_OFFSET_0], + at_weight_tensors_[IN_QKV_OFFSET_1], + at_weight_tensors_[IN_QKV_OFFSET_2]}, + 0) + .contiguous() + .view(-1); + at_weight_tensors_[IN_QKV_OFFSET_1] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_OFFSET_2] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] = + at_weight_tensors_[IN_ATTENTION_OUT_OFFSET].contiguous().view(-1); + + at_weight_tensors_[IN_QKV_SCALE_0] = + torch::cat({at_weight_tensors_[IN_QKV_SCALE_0], + at_weight_tensors_[IN_QKV_SCALE_1], + at_weight_tensors_[IN_QKV_SCALE_2]}, + 0) + .contiguous() + .view(-1); + at_weight_tensors_[IN_QKV_SCALE_1] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_QKV_SCALE_2] = + torch::zeros({1}, torch::kFloat16).to(device_); + at_weight_tensors_[IN_ATTENTION_OUT_SCALE] = + at_weight_tensors_[IN_ATTENTION_OUT_SCALE].contiguous().view(-1); + } +} + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/qwen3_moe_decoder_loader.h b/xllm/core/layers/npu/loader/qwen3_moe_decoder_loader.h new file mode 100644 index 000000000..5101a5374 --- /dev/null +++ b/xllm/core/layers/npu/loader/qwen3_moe_decoder_loader.h @@ -0,0 +1,94 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include "base_loader.h" + +namespace xllm { +namespace layer { + +class Qwen3MoeDecoderLoader : public BaseLoader { + public: + Qwen3MoeDecoderLoader(uint64_t weight_count, const ModelContext& context); + + void load_state_dict(const StateDict& state_dict) override; + void verify_loaded_weights(const std::string& prefix) const override; + void merge_loaded_weights() override; + void resize_experts_weights(int num_of_device_experts); + + protected: + std::string extract_endswith(const std::string& input); + + int extract_expert_index(const std::string& name); + + int get_mapped_index(const std::string& name, + const std::unordered_map& mapping); + + torch::Tensor get_sharded_tensor(const StateDict& state_dict, + const std::string& name, + int dim); + torch::Tensor get_sharded_tensor(const StateDict& state_dict, + const std::string& name, + int dim, + int local_tp_rank, + int local_tp_size); + + void process_mlp_common_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + + void process_general_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + + void merge_experts_weights(); + + torch::Tensor merge_experts_weights(std::vector& experts_up, + std::vector& experts_gate, + bool transpose = false); + + torch::Tensor merge_experts_weights(std::vector& experts, + bool transpose = false); + + void process_expert_weights(const StateDict& state_dict, + const std::string& name, + const torch::Tensor& tensor); + + int32_t ep_size_; + int32_t num_experts_; + int32_t num_experts_per_partition_; + int32_t ep_local_tp_size_; + int32_t ep_local_tp_rank_; + int32_t start_expert_id_; + int32_t end_expert_id_; + int32_t ep_rank_; + int32_t n_kv_heads_; + + int32_t dp_size_; + int32_t dp_local_tp_size_; + int32_t dp_rank_; + int32_t dp_local_tp_rank_; + + std::unordered_map shared_experts_weights_; + std::unordered_map> experts_weights_; +}; + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/qwen3_vision_encoder_loader.cpp b/xllm/core/layers/npu/loader/qwen3_vision_encoder_loader.cpp new file mode 100644 index 000000000..b4d14acbf --- /dev/null +++ b/xllm/core/layers/npu/loader/qwen3_vision_encoder_loader.cpp @@ -0,0 +1,162 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#ifdef TORCH_HIGHER_THAN_PTA6 +#include +#include +#else +#include +#include +#endif + +#include + +#include "qwen3_vision_encoder_loader.h" +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" +#include "torch_npu/csrc/core/npu/NPUException.h" + +namespace xllm { +namespace layer { + +enum VisionEncoderLayerTensorId : int { + IN_INPUT_NORM_WEIGHT = 0, + IN_INPUT_NORM_BIAS, + IN_POST_NORM_WEIGHT, + IN_POST_NORM_BIAS, + IN_QKV_WEIGHT, + IN_QKV_BIAS, + IN_WATTENTION_OUT_WEIGHT, + IN_WATTENTION_OUT_BIAS, + IN_LINEAR_FC1_WEIGHT, + IN_LINEAR_FC1_BIAS, + IN_LINEAR_FC2_WEIGHT, + IN_LINEAR_FC2_BIAS, + IN_VISION_Q_WEIGHT, + IN_VISION_Q_BIAS, + IN_VISION_K_WEIGHT, + IN_VISION_K_BIAS, + IN_VISION_V_WEIGHT, + IN_VISION_V_BIAS +}; + +static std::vector> WEIGHT_MAPPING = { + {IN_INPUT_NORM_WEIGHT, "norm1.weight"}, + {IN_INPUT_NORM_BIAS, "norm1.bias"}, + {IN_POST_NORM_WEIGHT, "norm2.weight"}, + {IN_POST_NORM_BIAS, "norm2.bias"}, + {IN_QKV_WEIGHT, "attn.qkv.weight"}, + {IN_QKV_BIAS, "attn.qkv.bias"}, + {IN_WATTENTION_OUT_WEIGHT, "attn.proj.weight"}, + {IN_WATTENTION_OUT_BIAS, "attn.proj.bias"}, + {IN_LINEAR_FC1_WEIGHT, "mlp.linear_fc1.weight"}, + {IN_LINEAR_FC1_BIAS, "mlp.linear_fc1.bias"}, + {IN_LINEAR_FC2_WEIGHT, "mlp.linear_fc2.weight"}, + {IN_LINEAR_FC2_BIAS, "mlp.linear_fc2.bias"}}; + +// {weight,dim} +static std::map WEIGHT_SHARD = { + {IN_WATTENTION_OUT_WEIGHT, 1}, + {IN_LINEAR_FC1_WEIGHT, 0}, + {IN_LINEAR_FC1_BIAS, 0}, + {IN_LINEAR_FC2_WEIGHT, 1}, +}; + +Qwen3VisionEncoderLoader::Qwen3VisionEncoderLoader(uint64_t weight_count, + const ModelContext& context) + : BaseLoader(weight_count, context) { + auto model_args = context.get_model_args(); + auto parallel_args = context.get_parallel_args(); + auto options = context.get_tensor_options(); + encode_param_rank = parallel_args.rank(); + encode_param_worldSize = parallel_args.world_size(); + at_weight_tensors_.resize(weight_count); + dtype_ = torch::typeMetaToScalarType(options.dtype()); + device_id_ = options.device().index(); + at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_); + for (int i = 0; i < weight_count; ++i) { + at_weight_tensors_[i] = torch::zeros({1}).to(options); + } +} + +void Qwen3VisionEncoderLoader::load_state_dict(const StateDict& state_dict) { + for (const auto& [index, name] : WEIGHT_MAPPING) { + if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { + set_weight(state_dict, name, index, WEIGHT_SHARD[index]); + } else { + set_weight(state_dict, name, index); + } + } +} + +void Qwen3VisionEncoderLoader::verify_loaded_weights() const { + for (const auto& [index, name] : WEIGHT_MAPPING) { + CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) + << "weight is not loaded for " << name; + } +} + +void Qwen3VisionEncoderLoader::merge_loaded_weights() { + // spilt pack qkv weight when enable tp + get_weights_col_packed_qkv(); + if (encode_param_worldSize > 1) { + // merge qkv weight + auto new_qkv_weight = torch::cat({at_weight_tensors_[IN_VISION_Q_WEIGHT], + at_weight_tensors_[IN_VISION_K_WEIGHT], + at_weight_tensors_[IN_VISION_V_WEIGHT]}, + 0); + at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight; + at_weight_tensors_[IN_VISION_Q_WEIGHT] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_VISION_K_WEIGHT] = torch::zeros({1}).to(device_); + at_weight_tensors_[IN_VISION_V_WEIGHT] = torch::zeros({1}).to(device_); + + // merge qkv bias + auto new_qkv_bias = torch::cat({at_weight_tensors_[IN_VISION_Q_BIAS], + 
                                     at_weight_tensors_[IN_VISION_K_BIAS],
+                                     at_weight_tensors_[IN_VISION_V_BIAS]},
+                                    0);
+    at_weight_tensors_[IN_QKV_BIAS] = new_qkv_bias;
+    at_weight_tensors_[IN_VISION_Q_BIAS] = torch::zeros({1}).to(device_);
+    at_weight_tensors_[IN_VISION_K_BIAS] = torch::zeros({1}).to(device_);
+    at_weight_tensors_[IN_VISION_V_BIAS] = torch::zeros({1}).to(device_);
+  }
+}
+
+// split weight for tp
+void Qwen3VisionEncoderLoader::get_weights_col_packed_qkv() {
+  int rank = encode_param_rank;
+  int worldSize = encode_param_worldSize;
+  // split qkv weight
+  qkv_weight = torch::chunk(at_weight_tensors_[IN_QKV_WEIGHT], 3, 0);
+  qkv_bias = torch::chunk(at_weight_tensors_[IN_QKV_BIAS], 3, 0);
+  // weight
+  at_weight_tensors_[IN_VISION_Q_WEIGHT] =
+      (qkv_weight[0].chunk(worldSize, 0))[rank];
+  at_weight_tensors_[IN_VISION_K_WEIGHT] =
+      (qkv_weight[1].chunk(worldSize, 0))[rank];
+  at_weight_tensors_[IN_VISION_V_WEIGHT] =
+      (qkv_weight[2].chunk(worldSize, 0))[rank];
+  // bias
+  at_weight_tensors_[IN_VISION_Q_BIAS] =
+      (qkv_bias[0].chunk(worldSize, 0))[rank];
+  at_weight_tensors_[IN_VISION_K_BIAS] =
+      (qkv_bias[1].chunk(worldSize, 0))[rank];
+  at_weight_tensors_[IN_VISION_V_BIAS] =
+      (qkv_bias[2].chunk(worldSize, 0))[rank];
+}
+
+} // namespace layer
+} // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/layers/npu/loader/qwen3_vision_encoder_loader.h b/xllm/core/layers/npu/loader/qwen3_vision_encoder_loader.h
new file mode 100644
index 000000000..86d0d5bba
--- /dev/null
+++ b/xllm/core/layers/npu/loader/qwen3_vision_encoder_loader.h
@@ -0,0 +1,49 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include
+#include
+
+#include "base_loader.h"
+
+namespace xllm {
+namespace layer {
+
+class Qwen3VisionEncoderLoader : public BaseLoader {
+ public:
+  Qwen3VisionEncoderLoader(uint64_t weight_count, const ModelContext& context);
+
+  void load_state_dict(const StateDict& state_dict) override;
+  void verify_loaded_weights() const override;
+  void merge_loaded_weights() override;
+
+ private:
+  void get_weights_col_packed_qkv();
+
+ protected:
+  std::string model_name_;
+  at::Tensor cu_seqlen_;
+  at::Tensor at_placeholder_;
+  std::vector<at::Tensor> qkv_weight;
+  std::vector<at::Tensor> qkv_bias;
+  int device_id_;
+  int encode_param_rank;
+  int encode_param_worldSize;
+};
+
+} // namespace layer
+} // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/layers/npu/loader/rms_norm_loader.cpp b/xllm/core/layers/npu/loader/rms_norm_loader.cpp
new file mode 100644
index 000000000..dc7b9bb99
--- /dev/null
+++ b/xllm/core/layers/npu/loader/rms_norm_loader.cpp
@@ -0,0 +1,41 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "rms_norm_loader.h"
+
+namespace xllm {
+namespace layer {
+
+RMSNORMLoader::RMSNORMLoader(uint64_t weight_count, const ModelContext& context)
+    : BaseLoader(weight_count, context) {
+  at_weight_tensors_.resize(1);
+
+  auto options = context.get_tensor_options();
+  dtype_ = torch::typeMetaToScalarType(options.dtype());
+  at_weight_tensors_[0] = torch::zeros({1}).to(options);
+}
+
+void RMSNORMLoader::load_state_dict(const StateDict& state_dict) {
+  set_weight(state_dict, "weight", 0);
+  at_weight_tensors_[0] = at_weight_tensors_[0].to(dtype_);
+}
+
+void RMSNORMLoader::verify_loaded_weights(const std::string& weight_str) const {
+  CHECK(at_weight_tensors_[0].sizes() != std::vector<int64_t>({1}))
+      << "final norm weight is not loaded for " << weight_str;
+}
+
+} // namespace layer
+} // namespace xllm
\ No newline at end of file
diff --git a/xllm/core/layers/npu/loader/rms_norm_loader.h b/xllm/core/layers/npu/loader/rms_norm_loader.h
new file mode 100644
index 000000000..2fee83859
--- /dev/null
+++ b/xllm/core/layers/npu/loader/rms_norm_loader.h
@@ -0,0 +1,40 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#pragma once + +#include +#include + +#include "base_loader.h" + +namespace xllm { +namespace layer { + +class RMSNORMLoader : public BaseLoader { + public: + RMSNORMLoader(uint64_t weight_count, const ModelContext& context); + + void load_state_dict(const StateDict& state_dict) override; + + void verify_loaded_weights(const std::string& weight_str) const override; + + protected: + int rank_id_; + torch::ScalarType dtype_; +}; + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/siglip_encoder_loader.cpp b/xllm/core/layers/npu/loader/siglip_encoder_loader.cpp new file mode 100644 index 000000000..ce5295883 --- /dev/null +++ b/xllm/core/layers/npu/loader/siglip_encoder_loader.cpp @@ -0,0 +1,55 @@ +#include "siglip_encoder_loader.h" + +namespace xllm { +namespace layer { + +SiglipEncoderUpLoader::SiglipEncoderUpLoader(const ModelContext& context) + : BaseLoader(0, context) { + options_ = context.get_tensor_options(); +} + +void SiglipEncoderUpLoader::load_state_dict(const StateDict& state_dict) { + const std::set key_names = {"layer_norm1.weight", + "layer_norm1.bias", + "self_attn.q_proj.weight", + "self_attn.q_proj.bias", + "self_attn.k_proj.weight", + "self_attn.k_proj.bias", + "self_attn.v_proj.weight", + "self_attn.v_proj.bias"}; + + for (const auto& [name, tensor] : state_dict) { + if (key_names.find(name) == key_names.end()) continue; + + auto weight_npu = tensor.to(options_); + + weights_map_[name] = weight_npu; + } +} + +SiglipEncoderDownLoader::SiglipEncoderDownLoader(const ModelContext& context) + : BaseLoader(0, context) { + options_ = context.get_tensor_options(); +} + +void SiglipEncoderDownLoader::load_state_dict(const StateDict& state_dict) { + const std::set key_names = {"self_attn.out_proj.weight", + "self_attn.out_proj.bias", + "layer_norm2.weight", + "layer_norm2.bias", + "mlp.fc1.weight", + "mlp.fc1.bias", + "mlp.fc2.weight", + "mlp.fc2.bias"}; + + for (const auto& [name, tensor] : state_dict) { + if (key_names.find(name) == key_names.end()) continue; + + auto weight_npu = tensor.to(options_); + + weights_map_[name] = weight_npu; + } +} + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/siglip_encoder_loader.h b/xllm/core/layers/npu/loader/siglip_encoder_loader.h new file mode 100644 index 000000000..fae5a60a5 --- /dev/null +++ b/xllm/core/layers/npu/loader/siglip_encoder_loader.h @@ -0,0 +1,40 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#pragma once + +#include +#include + +#include "base_loader.h" + +namespace xllm { +namespace layer { + +class SiglipEncoderUpLoader : public BaseLoader { + public: + explicit SiglipEncoderUpLoader(const ModelContext& context); + + void load_state_dict(const StateDict& state_dict) override; +}; + +class SiglipEncoderDownLoader : public BaseLoader { + public: + explicit SiglipEncoderDownLoader(const ModelContext& context); + + void load_state_dict(const StateDict& state_dict) override; +}; + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/word_embedding_loader.cpp b/xllm/core/layers/npu/loader/word_embedding_loader.cpp new file mode 100644 index 000000000..1c10476c1 --- /dev/null +++ b/xllm/core/layers/npu/loader/word_embedding_loader.cpp @@ -0,0 +1,40 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "word_embedding_loader.h" + +namespace xllm { +namespace layer { + +WordEmbeddingLoader::WordEmbeddingLoader(uint64_t weight_count, + const ModelContext& context) + : BaseLoader(weight_count, context) {} + +void WordEmbeddingLoader::load_state_dict(const StateDict& state_dict) { + if (dp_size_ > 1) { + set_weight( + state_dict, "weight", 0, 1, dp_local_tp_rank_, dp_local_tp_size_); + } else { + set_weight(state_dict, "weight", 0, 1); + } +} + +void WordEmbeddingLoader::verify_loaded_weights( + const std::string& weight_str) const { + CHECK(at_weight_tensors_[0].sizes() != std::vector({1})) + << "weight is not loaded for " << weight_str; +} + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/loader/word_embedding_loader.h b/xllm/core/layers/npu/loader/word_embedding_loader.h new file mode 100644 index 000000000..fa27adfb0 --- /dev/null +++ b/xllm/core/layers/npu/loader/word_embedding_loader.h @@ -0,0 +1,35 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include +#include + +#include "core/layers/npu/npu_base_layer.h" + +namespace xllm { +namespace layer { + +class WordEmbeddingLoader : public BaseLoader { + public: + WordEmbeddingLoader(uint64_t weight_count, const ModelContext& context); + + void load_state_dict(const StateDict& state_dict) override; + void verify_loaded_weights(const std::string& prefix) const override; +}; + +} // namespace layer +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/npu_base_layer.h b/xllm/core/layers/npu/npu_base_layer.h index 2b5319ebc..c29f098d8 100644 --- a/xllm/core/layers/npu/npu_base_layer.h +++ b/xllm/core/layers/npu/npu_base_layer.h @@ -39,6 +39,7 @@ limitations under the License. #include "framework/model_context.h" #include "framework/state_dict/state_dict.h" #include "framework/xtensor/xtensor.h" +#include "loader/base_loader.h" #include "pytorch/adapter/utils/utils.h" #include "pytorch/adapter/workspace/workspace.h" @@ -115,11 +116,30 @@ class BaseLayer : public torch::nn::Module { aclrtEvent* event, std::atomic* event_flag); - virtual void load_state_dict(const StateDict& state_dict) {}; - - virtual void verify_loaded_weights() const {}; - - virtual void merge_loaded_weights() {}; + virtual void load_state_dict(const StateDict& state_dict) { + if (loader_) { + loader_->load_state_dict(state_dict); + } + }; + + virtual void verify_loaded_weights() const { + if (loader_) { + loader_->verify_loaded_weights(); + } + }; + + virtual void verify_loaded_weights(const std::string& prefix) const { + if (loader_) { + loader_->verify_loaded_weights(prefix); + } + }; + + virtual void merge_loaded_weights() { + if (loader_) { + loader_->merge_loaded_weights(); + } + init_layer(); + }; virtual int64_t init_layer() { return 0; }; @@ -150,6 +170,7 @@ class BaseLayer : public torch::nn::Module { atb::Tensor XTensor2Tensor(const std::shared_ptr& xtensor); protected: + std::unique_ptr loader_ = nullptr; std::vector at_weight_tensors_; at::Device device_; std::string name_; diff --git a/xllm/core/layers/npu/npu_block_copy_impl.h b/xllm/core/layers/npu/npu_block_copy_impl.h index e5c049227..082c62b8c 100644 --- a/xllm/core/layers/npu/npu_block_copy_impl.h +++ b/xllm/core/layers/npu/npu_block_copy_impl.h @@ -47,8 +47,6 @@ class BlockCopyImpl : public BaseLayer { void load_state_dict(const StateDict& state_dict) {}; - void verify_loaded_weights(const std::string weight_str) const {}; - void merge_loaded_weights(); int64_t init_layer(); diff --git a/xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp b/xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp index 89023ce76..36f71f72e 100644 --- a/xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp +++ b/xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp @@ -45,39 +45,24 @@ ColumnParallelLinearImpl::ColumnParallelLinearImpl(const ModelContext& context) : BaseLayer(context) { param_from_args( linear_param_, context.get_model_args(), context.get_parallel_args()); - at_weight_tensors_.resize(1); atb_weight_tensors_.resize(1); at_out_tensors_.resize(1); auto options = context.get_tensor_options(); dtype_ = c10::typeMetaToScalarType(options.dtype()); - at_weight_tensors_[0] = torch::zeros({1}).to(options); tensor_placeholder_ = torch::zeros({1}).to(options); placeholder_ = atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_); -} - -void ColumnParallelLinearImpl::verify_loaded_weights( - const std::string weight_str) const 
{ - CHECK(at_weight_tensors_[0].sizes() != std::vector({1})) - << "weight is not loaded for " << weight_str; + loader_ = std::make_unique(1, context); } void ColumnParallelLinearImpl::merge_loaded_weights() { + auto& at_weight_tensors = loader_->get_at_weight_tensors(); + atb_weight_tensors_[0] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[0]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[0]); init_layer(); } -void ColumnParallelLinearImpl::load_state_dict(const StateDict& state_dict) { - if (dp_size_ > 1) { - set_weight( - state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_); - } else { - set_weight(state_dict, "weight", 0, 0); - } - at_weight_tensors_[0] = at_weight_tensors_[0].to(dtype_); -} - int64_t ColumnParallelLinearImpl::init_layer() { name_ = "atb_parallel_linear_layer"; model_name_ = "Atb Parallel Linear"; diff --git a/xllm/core/layers/npu/npu_column_parallel_linear_impl.h b/xllm/core/layers/npu/npu_column_parallel_linear_impl.h index 54836cb06..9391d8d16 100644 --- a/xllm/core/layers/npu/npu_column_parallel_linear_impl.h +++ b/xllm/core/layers/npu/npu_column_parallel_linear_impl.h @@ -31,6 +31,7 @@ limitations under the License. #include "framework/model/model_input_params.h" #include "framework/model_context.h" #include "framework/state_dict/state_dict.h" +#include "loader/column_parallel_linear_loader.h" #include "nlohmann/json.hpp" #include "npu_base_layer.h" #include "pytorch/adapter/utils/utils.h" @@ -51,10 +52,6 @@ class ColumnParallelLinearImpl : public BaseLayer { ~ColumnParallelLinearImpl() {}; - virtual void load_state_dict(const StateDict& state_dict) override; - - void verify_loaded_weights(const std::string weight_str) const; - virtual void merge_loaded_weights() override; virtual int64_t init_layer() override; diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp index 5f6454267..8b54de329 100644 --- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.cpp @@ -129,127 +129,6 @@ enum DecoderLayerTensorId : int { static const uint64_t WEIGHT_COUNT_PER_LAYER = 84; -static std::vector> WEIGHT_MAPPING = {}; - -static const std::unordered_map WEIGHT_MAPPING_W8A8 = { - {"input_layernorm.weight", IN_INPUT_NORM_WEIGHT}, - {"input_layernorm.bias", IN_INPUT_NORM_BIAS}, - - {"self_attn.q_a_proj.weight", IN_Q_PROJ_A_WEIGHT}, - {"self_attn.q_a_proj.quant_bias", IN_Q_PROJ_A_BIAS}, - {"self_attn.q_a_proj.deq_scale", IN_Q_PROJ_A_DESCALE}, - {"self_attn.q_a_proj.input_offset", IN_Q_PROJ_A_OFFSET}, - {"self_attn.q_a_proj.input_scale", IN_Q_PROJ_A_SCALE}, - {"self_attn.q_a_layernorm.weight", IN_Q_PROJ_A_LAYERNORM_WEIGHT}, - {"self_attn.q_a_layernorm.bias", IN_Q_PROJ_A_LAYERNORM_BIAS}, - - {"self_attn.q_proj.weight", IN_Q_PROJ_B_WEIGHT}, - {"self_attn.q_b_proj.weight", IN_Q_PROJ_B_WEIGHT}, - {"self_attn.q_b_proj.quant_bias", IN_Q_PROJ_B_BIAS}, - {"self_attn.q_b_proj.input_scale", IN_Q_PROJ_B_SCALE}, - {"self_attn.q_b_proj.deq_scale", IN_Q_PROJ_B_DESCALE}, - {"self_attn.q_b_proj.input_offset", IN_Q_PROJ_B_OFFSET}, - - {"self_attn.kv_a_proj_with_mqa.weight", IN_KV_PROJ_WITH_MQA_WEIGHT}, - {"self_attn.kv_a_proj_with_mqa.quant_bias", IN_KV_PROJ_WITH_MQA_BIAS}, - {"self_attn.kv_a_proj_with_mqa.deq_scale", IN_KV_PROJ_WITH_MQA_DESCALE}, - {"self_attn.kv_a_proj_with_mqa.input_offset", IN_KV_PROJ_WITH_MQA_OFFSET}, - {"self_attn.kv_a_proj_with_mqa.input_scale", IN_KV_PROJ_WITH_MQA_SCALE}, - - 
{"self_attn.kv_a_layernorm.weight", IN_KV_PROJ_A_LAYERNORM_WEIGHT}, - {"self_attn.kv_a_layernorm.bias", IN_KV_PROJ_A_LAYERNORM_BIAS}, - - {"self_attn.kv_b_proj.weight", IN_K_PROJ_B_FOR_Q_WEIGHT}, // merge - // {"self_attn.kv_b_proj.weight", IN_V_PROJ_B_FOR_O_WEIGHT}, // merge - - {"self_attn.o_proj.weight", IN_ATTENTION_OUT_WEIGHT}, - {"self_attn.o_proj.quant_bias", IN_ATTENTION_OUT_BIAS}, - {"self_attn.o_proj.deq_scale", IN_ATTENTION_OUT_DESCALE}, - {"self_attn.o_proj.input_offset", IN_ATTENTION_OUT_OFFSET}, - {"self_attn.o_proj.input_scale", IN_ATTENTION_OUT_SCALE}, - - {"post_attention_layernorm.weight", IN_SELFATTENTION_OUT_NORM_WEIGHT}, - {"post_attention_layernorm.bias", IN_SELFATTENTION_OUT_NORM_BIAS}, - - {"mlp.gate_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, - {"mlp.gate_proj.weight_offset", IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, - {"mlp.gate_proj.weight_scale", IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, - - {"mlp.up_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, - {"mlp.up_proj.weight_offset", IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, - {"mlp.up_proj.weight_scale", IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, - - {"mlp.down_proj.weight", IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, - {"mlp.down_proj.weight_offset", IN_MLP_DOWN_OFFSET_SHARED_EXPERT}, - {"mlp.down_proj.weight_scale", IN_MLP_DOWN_SCALE_SHARED_EXPERT}, - - {"mlp.shared_experts.gate_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, - {"mlp.shared_experts.gate_proj.weight_offset", - IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, - {"mlp.shared_experts.gate_proj.weight_scale", - IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, - - {"mlp.shared_experts.up_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, - {"mlp.shared_experts.up_proj.weight_offset", - IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, - {"mlp.shared_experts.up_proj.weight_scale", - IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, - - {"mlp.shared_experts.down_proj.weight", IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, - {"mlp.shared_experts.down_proj.weight_offset", - IN_MLP_DOWN_OFFSET_SHARED_EXPERT}, - {"mlp.shared_experts.down_proj.weight_scale", - IN_MLP_DOWN_SCALE_SHARED_EXPERT}, - - {"mlp.gate.weight", IN_BLOCK_SPARSE_MOE_GATE_WEIGHT}, - {"mlp.gate.e_score_correction_bias", IN_BLOCK_SPARSE_MOE_GATE_BIAS}, - - {"gate_proj.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, - {"gate_proj.weight_offset", IN_MLP_GATEUP_OFFSET_EXPERT}, - {"gate_proj.weight_scale", IN_MLP_GATEUP_SCALE_EXPERT}, - {"up_proj.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, - {"up_proj.weight_offset", IN_MLP_GATEUP_OFFSET_EXPERT}, - {"up_proj.weight_scale", IN_MLP_GATEUP_SCALE_EXPERT}, - - {"down_proj.weight", IN_MLP_DOWN_WEIGHT_EXPERT}, - {"down_proj.weight_offset", IN_MLP_DOWN_OFFSET_EXPERT}, - {"down_proj.weight_scale", IN_MLP_DOWN_SCALE_EXPERT}, -}; - -static const std::map WEIGHT_SHARD = {}; - -static const std::map WEIGHT_SHARD_W8A8 = { - {IN_Q_PROJ_B_WEIGHT, 0}, - {IN_Q_PROJ_B_BIAS, 0}, - {IN_Q_PROJ_B_DESCALE, 0}, - {IN_K_PROJ_B_FOR_Q_WEIGHT, 0}, - {IN_V_PROJ_B_FOR_O_WEIGHT, 0}, - {IN_ATTENTION_OUT_WEIGHT, 1}, - {IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, 0}, - {IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, 0}, - {IN_MLP_GATEUP_SCALE_SHARED_EXPERT, 0}, - {IN_MLP_DOWN_WEIGHT_SHARED_EXPERT, 1}, - {IN_MLP_GATEUP_WEIGHT_EXPERT, 0}, - {IN_MLP_GATEUP_OFFSET_EXPERT, 0}, - {IN_MLP_GATEUP_SCALE_EXPERT, 0}, - {IN_MLP_DOWN_WEIGHT_EXPERT, 1}, -}; - -static std::vector SQUEEZE_WEIGHT_VEC = { - IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, - IN_MLP_GATEUP_SCALE_SHARED_EXPERT, - IN_MLP_DOWN_OFFSET_SHARED_EXPERT, - IN_MLP_DOWN_SCALE_SHARED_EXPERT}; - -static std::vector LINEAR_FOR_ROPE = { - 
"self_attn.q_b_proj.weight", - "self_attn.q_b_proj.quant_bias", - "self_attn.q_b_proj.deq_scale", - "self_attn.kv_a_proj_with_mqa.weight", - "self_attn.kv_a_proj_with_mqa.quant_bias", - "self_attn.kv_a_proj_with_mqa.deq_scale", -}; - DeepseekV2DecoderLayerImpl::DeepseekV2DecoderLayerImpl( const ModelContext& context, const int32_t layer_id) @@ -279,6 +158,7 @@ DeepseekV2DecoderLayerImpl::DeepseekV2DecoderLayerImpl( auto parallel_args = context.get_parallel_args(); auto model_args = context.get_model_args(); auto options = context.get_tensor_options(); + rank_ = parallel_args.rank(); first_k_dense_replace_ = model_args.first_k_dense_replace(); n_layers_ = model_args.n_layers(); @@ -307,13 +187,27 @@ DeepseekV2DecoderLayerImpl::DeepseekV2DecoderLayerImpl( param_from_args(decode_mla_param_, model_args, parallel_args, false); decode_mla_param_.enableCustomizeMla = FLAGS_enable_customize_mla_kernel; + loader_ = std::make_unique( + WEIGHT_COUNT_PER_LAYER, + context, + layer_id_, + prefill_param_.firstKDenseReplace, + prefill_param_.numOfDeviceExperts, + prefill_param_.qkRopeHeadDim, + prefill_param_.numAttentionHeadsPerRank, + decode_param_.worldSize, + qk_nope_head_dim_, + kv_lora_rank_, + num_key_value_heads_, + v_head_dim_, + prefill_param_.isBF16, + decode_param_.isBF16); initialize_tensors(options); } void DeepseekV2DecoderLayerImpl::initialize_tensors( const torch::TensorOptions& options) { // initializ placeholder - at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); placeholder_vec_ = {1}; int_tensor_placeholder_ = torch::ones({1}).to(torch::kInt32).to(device_); @@ -322,7 +216,6 @@ void DeepseekV2DecoderLayerImpl::initialize_tensors( torch::zeros({1, 1}).to(torch::kInt32).to(device_); tensor_placeholder_ = torch::zeros({1}).to(options); - reserve_experts_weights(prefill_param_.numOfDeviceExperts); expert_group_ = torch::arange(1024, torch::kInt32).to(device_); one_hot_ = torch::tensor({1}, torch::kInt32).to(device_); zero_hot_ = torch::tensor({0}, torch::kInt32).to(device_); @@ -331,12 +224,11 @@ void DeepseekV2DecoderLayerImpl::initialize_tensors( at_in_device_expert_count_ = torch::tensor({num_experts_per_partition_ - 1}, torch::kInt64) .to(device_); - initialize_weight_tensors(options); - initialize_device_expert_list(decode_param_.worldSize, - num_experts_per_partition_); + + auto& device_expert_list = loader_->get_device_expert_list(); if (FLAGS_enable_eplb) { auto layer_expert_routing_map_ = - build_expert_routing_map(device_expert_list_); + build_expert_routing_map(device_expert_list); std::vector tensors_vec; for (int i = 0; i < n_layers_ - first_k_dense_replace_; i++) { tensors_vec.emplace_back(layer_expert_routing_map_); @@ -345,23 +237,6 @@ void DeepseekV2DecoderLayerImpl::initialize_tensors( } } -void DeepseekV2DecoderLayerImpl::initialize_device_expert_list( - int num_device, - int num_device_expert) { - int32_t num_device_route_expert = num_device_expert; - if (FLAGS_enable_eplb) { - num_device_route_expert = num_device_expert - redundant_experts_num_; - } - for (int i = 0; i < num_device * num_device_route_expert; ++i) { - device_expert_list_.emplace_back(i); - if (FLAGS_enable_eplb && (i + 1) % num_device_route_expert == 0) { - for (int redundant_expert = 0; redundant_expert < redundant_experts_num_; - ++redundant_expert) - device_expert_list_.emplace_back(i); - } - } -} - void DeepseekV2DecoderLayerImpl::param_from_args( atb_speed::deepseekV2::DecoderLayerParam& param, const ModelArgs& args, @@ -375,40 +250,6 @@ void 
DeepseekV2DecoderLayerImpl::param_from_args( initialize_kimi_k2_parameters(param, args, is_prefill); } -void DeepseekV2DecoderLayerImpl::reserve_experts_weights( - int num_of_device_experts) { - experts_weights_.clear(); - std::vector weight_names = { - "gate_proj.weight", "up_proj.weight", "down_proj.weight"}; - if (quantize_type_ == "w8a8_dynamic") { - weight_names.emplace_back("gate_proj.weight_offset"); - weight_names.emplace_back("up_proj.weight_offset"); - weight_names.emplace_back("down_proj.weight_offset"); - weight_names.emplace_back("gate_proj.weight_scale"); - weight_names.emplace_back("up_proj.weight_scale"); - weight_names.emplace_back("down_proj.weight_scale"); - } - std::lock_guard lock(experts_mutex_); - for (const auto& weight_name : weight_names) { - experts_weights_[weight_name] = - std::vector(num_of_device_experts); - ; - } -} - -void DeepseekV2DecoderLayerImpl::initialize_weight_tensors( - const torch::TensorOptions& options) { - for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { - at_weight_tensors_[i] = torch::zeros({1}).to(options); - } - if (FLAGS_enable_eplb) { - const int64_t size = - 50LL * 1024LL * 1024LL * int64_t(n_layers_ - first_k_dense_replace_); - shared_buffer_ = std::make_unique( - num_experts_, n_layers_ - first_k_dense_replace_, size); - } -} - void DeepseekV2DecoderLayerImpl::initialize_basic_parameters( atb_speed::deepseekV2::DecoderLayerParam& param, const ModelArgs& args, @@ -659,692 +500,70 @@ void DeepseekV2DecoderLayerImpl::initialize_quantization_parameters( } } -void DeepseekV2DecoderLayerImpl::load_state_dict(const StateDict& state_dict) { - for (const auto& [name, tensor] : state_dict) { - bool is_sharded = false; - int index = 0; - - if (absl::EndsWith(name, "self_attn.kv_b_proj.weight")) { - index = WEIGHT_MAPPING_W8A8.at(name); - set_kv_weight(state_dict, name, index, WEIGHT_SHARD_W8A8.at(index)); - continue; - } - - if (absl::StartsWith(name, "mlp.experts")) { - process_expert_weights(state_dict, name, tensor); - continue; - } - - if (absl::StartsWith(name, "mlp.shared_experts")) { - process_shared_expert_weights(state_dict, name, tensor); - continue; - } - - if (absl::StartsWith(name, "mlp") && !absl::StrContains(name, "gate.")) { - process_mlp_common_weights(state_dict, name, tensor); - continue; - } - - process_general_weights(state_dict, name, tensor); - } -} - -int DeepseekV2DecoderLayerImpl::get_mapped_index( - const std::string& name, - const std::unordered_map& mapping) { - const auto it = mapping.find(name); - if (it == mapping.end()) { - LOG(WARNING) << "Parameter '" << name - << "' not found in mapping and will not be used."; - return -1; - } - return it->second; -} - -std::string DeepseekV2DecoderLayerImpl::get_expert_shm_key( - int32_t layer_id, - int32_t expert_index, - const std::string& suffix) { - std::string shm_key = - "layer_" + std::to_string(layer_id - first_k_dense_replace_) + "_" + - "expert_" + std::to_string(expert_index) + "_" + suffix; - return shm_key; -} - -void DeepseekV2DecoderLayerImpl::process_expert_weights( - const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor) { - // Step 1: Early checks and basic info extraction - int expert_index = extract_expert_index(name); - const std::string suffix = extract_endswith(name); - const int index = get_mapped_index(suffix, WEIGHT_MAPPING_W8A8); - if (index == -1) { - return; - } - - const bool is_sharded = WEIGHT_SHARD_W8A8.count(index); - const bool needs_eplb = FLAGS_enable_eplb && (rank_ % localWorldSize_ == - expert_index % 
localWorldSize_); - - // Step 2: Check if expert is in partition - const int start_idx = ep_rank_ * num_experts_per_partition_; - const int end_idx = (ep_rank_ + 1) * num_experts_per_partition_; - const int safe_end = - std::min(end_idx, static_cast(device_expert_list_.size())); - - auto it = std::find(device_expert_list_.begin() + start_idx, - device_expert_list_.begin() + safe_end, - expert_index); - const bool in_partition = it != device_expert_list_.begin() + safe_end; - - // Early return if neither EPLB nor partition needs this expert - if (!needs_eplb && !in_partition) { - return; - } - - // Step 3: Process tensor - torch::Tensor processed_tensor; - { - std::lock_guard lock(experts_mutex_); - processed_tensor = is_sharded - ? get_sharded_tensor(state_dict, - name, - WEIGHT_SHARD_W8A8.at(index), - ep_local_tp_rank_, - ep_local_tp_size_) - : tensor; - - if (!decode_param_.isBF16) { - if (absl::EndsWith(name, "_offset")) { - processed_tensor = processed_tensor.to(torch::kFloat16); - } else if (absl::EndsWith(name, "_scale")) { - processed_tensor = processed_tensor.to(torch::kFloat32); - } - } - } - - // Step 4: Handle EPLB case - if (needs_eplb) { - std::lock_guard lock(experts_mutex_); - std::string shm_key = get_expert_shm_key(layer_id_, expert_index, suffix); - shared_buffer_->add_tensor(expert_index, - layer_id_ - first_k_dense_replace_, - shm_key, - processed_tensor.contiguous()); - } - - // Step 5: Handle partition case - if (in_partition) { - std::vector matches_pos; - for (auto iter = it; iter != device_expert_list_.begin() + safe_end; - ++iter) { - if (*iter == expert_index) { - matches_pos.emplace_back( - std::distance(device_expert_list_.begin(), iter) - start_idx); - } - } - - if (!matches_pos.empty()) { - std::lock_guard lock(experts_mutex_); - for (auto pos : matches_pos) { - experts_weights_[suffix][pos] = processed_tensor.clone(); - } - } - } -} - -void DeepseekV2DecoderLayerImpl::process_shared_expert_weights( - const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor) { - torch::Tensor tmp_tensor; - std::lock_guard lock(shared_experts_mutex_); - const int index = get_mapped_index(name, WEIGHT_MAPPING_W8A8); - if (index == -1) { - return; - } - if (FLAGS_expert_parallel_degree == 2) { - tmp_tensor = tensor.to(device_); - } else { - const bool is_sharded = WEIGHT_SHARD_W8A8.count(index); - tmp_tensor = is_sharded ? get_sharded_tensor( - state_dict, name, WEIGHT_SHARD_W8A8.at(index)) - .to(device_) - : tensor.to(device_); - } - if (absl::StrContains(name, "down_proj")) { - at_weight_tensors_[index] = tmp_tensor; - } else { - shared_experts_weights_[name] = tmp_tensor; - } -} - -void DeepseekV2DecoderLayerImpl::process_mlp_common_weights( - const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor) { - const int index = get_mapped_index(name, WEIGHT_MAPPING_W8A8); - if (index == -1) { - return; - } - const bool is_sharded = WEIGHT_SHARD_W8A8.count(index); - std::lock_guard lock(shared_experts_mutex_); - - torch::Tensor tmp_tensor = - is_sharded ? 
get_sharded_tensor(state_dict, - name, - WEIGHT_SHARD_W8A8.at(index), - dp_local_tp_rank_, - dp_local_tp_size_) - .to(device_) - : tensor.to(device_); - if (absl::StrContains(name, "down_proj")) { - at_weight_tensors_[index] = tmp_tensor; - } else { - shared_experts_weights_[name] = tmp_tensor; - } -} - -void DeepseekV2DecoderLayerImpl::process_general_weights( - const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor) { - const int index = get_mapped_index(name, WEIGHT_MAPPING_W8A8); - if (index == -1) { - return; - } - const bool is_sharded = WEIGHT_SHARD_W8A8.count(index); - torch::Tensor tmp_tensor; - - tmp_tensor = is_sharded ? get_sharded_tensor(state_dict, - name, - WEIGHT_SHARD_W8A8.at(index), - dp_local_tp_rank_, - dp_local_tp_size_) - .to(device_) - : tensor.to(device_); - - correct_tensor_dtype(tmp_tensor, name); - at_weight_tensors_[index] = tmp_tensor; -} - -void DeepseekV2DecoderLayerImpl::set_kv_weight(const StateDict& state_dict, - const std::string& tensor_name, - int weight_position, - int dim) { - torch::Tensor mutable_tensor; - if (parallel_args_.world_size() <= 1) { - mutable_tensor = state_dict.get_tensor(tensor_name).to(device_); - correct_tensor_dtype(mutable_tensor, tensor_name); - } else { - mutable_tensor = - get_sharded_tensor( - state_dict, tensor_name, dim, dp_local_tp_rank_, dp_local_tp_size_) - .to(device_); - // mutable_tensor = get_sharded_tensor(state_dict, tensor_name, dim); - correct_tensor_dtype(mutable_tensor, tensor_name); - } - - torch::Tensor kv_b_proj_weight = - mutable_tensor.reshape({num_key_value_heads_ / dp_local_tp_size_, - qk_nope_head_dim_ + v_head_dim_, - kv_lora_rank_}); - torch::Tensor k_b_proj_preprocessed = - kv_b_proj_weight.slice(1, 0, qk_nope_head_dim_).contiguous(); - torch::Tensor v_b_proj_preprocessed = - kv_b_proj_weight - .slice(1, qk_nope_head_dim_, qk_nope_head_dim_ + v_head_dim_) - .transpose(1, 2) - .contiguous(); - at_weight_tensors_[weight_position] = k_b_proj_preprocessed.to(device_); - at_weight_tensors_[weight_position + 6] = v_b_proj_preprocessed.to(device_); -} - -void DeepseekV2DecoderLayerImpl::preprocess_linear_for_rope() { - for (const auto& name : LINEAR_FOR_ROPE) { - if (quantize_type_ == "") { - if (!absl::EndsWith(name, "weight")) { - continue; - } - } - int index = WEIGHT_MAPPING_W8A8.at(name); - at_weight_tensors_[index] = - view_tensor(at_weight_tensors_[index], name, true); - at_weight_tensors_[index] = trans_rope_weight(at_weight_tensors_[index]); - at_weight_tensors_[index] = - (!absl::EndsWith(name, "weight")) - ? 
view_tensor(at_weight_tensors_[index], name, false).flatten() - : view_tensor(at_weight_tensors_[index], name, false); - } -} - -torch::Tensor DeepseekV2DecoderLayerImpl::view_tensor(torch::Tensor weight, - const std::string& name, - bool pre_view) { - if (absl::StrContains(name, "q_b_proj")) { - if (pre_view) { - return weight - .view({prefill_param_.numAttentionHeadsPerRank, - qk_nope_head_dim_ + prefill_param_.qkRopeHeadDim, - -1}) - .contiguous(); - } else { - return weight - .view({prefill_param_.numAttentionHeadsPerRank * - (qk_nope_head_dim_ + prefill_param_.qkRopeHeadDim), - -1}) - .contiguous(); - } - } else if (absl::StrContains(name, "kv_a_proj_with_mqa")) { - return weight.view({kv_lora_rank_ + prefill_param_.qkRopeHeadDim, -1}) - .contiguous(); - } - return weight; -} - -torch::Tensor DeepseekV2DecoderLayerImpl::trans_rope_weight( - torch::Tensor weight) { - int64_t d = weight.size(-2); - int64_t rope_dim = prefill_param_.qkRopeHeadDim; - torch::Tensor weight_1 = - weight.slice(-2, d - rope_dim, torch::indexing::None, 2).contiguous(); - - torch::Tensor weight_2 = - weight.slice(-2, d - rope_dim + 1, torch::indexing::None, 2).contiguous(); - - torch::Tensor combined = torch::cat({weight_1, weight_2}, -2); - - weight.slice(-2, d - rope_dim, d).copy_(combined); - - return weight.contiguous(); -} - -torch::Tensor DeepseekV2DecoderLayerImpl::get_sharded_tensor( - const StateDict& state_dict, - const std::string& name, - int dim) { - if (parallel_args_.world_size() > 1) { - return state_dict.get_sharded_tensor( - name, dim, parallel_args_.rank(), parallel_args_.world_size()); - } else { - return state_dict.get_tensor(name); - } -} - -torch::Tensor DeepseekV2DecoderLayerImpl::get_sharded_tensor( - const StateDict& state_dict, - const std::string& name, - int dim, - int loacal_tp_rank, - int local_tp_size) { - if (local_tp_size > 1) { - return state_dict.get_sharded_tensor( - name, dim, loacal_tp_rank, local_tp_size); - } else { - return state_dict.get_tensor(name); - } -} - -std::string DeepseekV2DecoderLayerImpl::extract_endswith( - const std::string& input) { - std::vector parts; - std::stringstream ss(input); - std::string part; - while (std::getline(ss, part, '.')) { - parts.emplace_back(part); - } - if (parts.size() < 2) { - return ""; - } - std::string result = parts[parts.size() - 2] + "." 
+ parts[parts.size() - 1]; - return result; -} - -int DeepseekV2DecoderLayerImpl::extract_expert_index(const std::string& name) { - std::string prefix = "experts."; - size_t pos = name.find(prefix); - if (pos != std::string::npos) { - pos += prefix.length(); - size_t end_pos = pos; - while (end_pos < name.length() && std::isdigit(name[end_pos])) { - ++end_pos; - } - if (end_pos > pos) { - return std::stoi(name.substr(pos, end_pos - pos)); - } - } - return -1; -} - -void DeepseekV2DecoderLayerImpl::verify_loaded_weights( - const std::string& prefix) const { - for (const auto& [index, name] : WEIGHT_MAPPING) { - CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) - << "weight is not loaded for " << prefix + name; - } -} - void DeepseekV2DecoderLayerImpl::merge_loaded_weights() { - if (quantize_type_ == "w8a8_dynamic") { - if (prefill_param_.isBF16) { - convert_descaled_weights_to_float(); - } - convert_offsets_to_int8(); - handle_device_specific_bias(); - } - - merge_shared_experts_weights(); - if (layer_id_ >= prefill_param_.firstKDenseReplace) { - merge_experts_weights(); - } - - squeeze_experts_weights(); - - preprocess_linear_for_rope(); - - at_weight_tensors_[IN_Q_PROJ_A_WEIGHT] = - torch::cat({at_weight_tensors_[IN_KV_PROJ_WITH_MQA_WEIGHT], - at_weight_tensors_[IN_Q_PROJ_A_WEIGHT]}, - 0) - .contiguous(); - if (quantize_type_ == "w8a8_dynamic") { - at_weight_tensors_[IN_Q_PROJ_A_BIAS] = - torch::cat({at_weight_tensors_[IN_KV_PROJ_WITH_MQA_BIAS], - at_weight_tensors_[IN_Q_PROJ_A_BIAS]}, - 0) - .contiguous(); - at_weight_tensors_[IN_Q_PROJ_A_DESCALE] = - torch::cat({at_weight_tensors_[IN_KV_PROJ_WITH_MQA_DESCALE], - at_weight_tensors_[IN_Q_PROJ_A_DESCALE]}, - 0) - .contiguous(); - } - - at_weight_tensors_[IN_Q_PROJ_A_WEIGHT] = at_npu::native::npu_format_cast( - at_weight_tensors_[IN_Q_PROJ_A_WEIGHT], 29); - at_weight_tensors_[IN_Q_PROJ_B_WEIGHT] = at_npu::native::npu_format_cast( - at_weight_tensors_[IN_Q_PROJ_B_WEIGHT], 29); - - at_weight_tensors_[IN_KV_PROJ_WITH_MQA_WEIGHT] = tensor_placeholder_; - at_weight_tensors_[IN_KV_PROJ_WITH_MQA_BIAS] = tensor_placeholder_; - at_weight_tensors_[IN_KV_PROJ_WITH_MQA_DESCALE] = tensor_placeholder_; - at_weight_tensors_[IN_KV_PROJ_WITH_MQA_OFFSET] = tensor_placeholder_; - at_weight_tensors_[IN_KV_PROJ_WITH_MQA_SCALE] = tensor_placeholder_; - if (FLAGS_expert_parallel_degree != 2) { - at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT] = - torch::roll(at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT], - {-1 * ep_rank_ * num_experts_per_partition_}, - {0}) - .contiguous(); - at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_BIAS] = - torch::roll(at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_BIAS], - {-1 * ep_rank_ * num_experts_per_partition_}, - {0}) - .contiguous(); - } - // at_weight_tensors_[IN_MLP_DOWN_WEIGHT_SHARED_EXPERT] = - // at_weight_tensors_[IN_MLP_DOWN_WEIGHT_SHARED_EXPERT].transpose(0, 1); - at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT] = - at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT].to(torch::kFloat32); - if (quantize_type_ == "w8a8_dynamic") { - // at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT] = - // at_weight_tensors_[IN_BLOCK_SPARSE_MOE_GATE_WEIGHT].to(torch::kFloat32); - if (!prefill_param_.isBF16) { - at_weight_tensors_[IN_Q_PROJ_A_DESCALE] = - convert_fp16_to_int64(at_weight_tensors_[IN_Q_PROJ_A_DESCALE]); - at_weight_tensors_[IN_Q_PROJ_B_DESCALE] = - convert_fp16_to_int64(at_weight_tensors_[IN_Q_PROJ_B_DESCALE]); - at_weight_tensors_[IN_ATTENTION_OUT_DESCALE] = - 
convert_fp16_to_int64(at_weight_tensors_[IN_ATTENTION_OUT_DESCALE]); - - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT] = - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT].to( - torch::kFloat16); - at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT] = - at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT].to( - torch::kFloat32); - at_weight_tensors_[IN_MLP_DOWN_SCALE_SHARED_EXPERT] = - at_weight_tensors_[IN_MLP_DOWN_SCALE_SHARED_EXPERT].to( - torch::kFloat32); - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT] = - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT].to(torch::kFloat16); - at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT] = - at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT].to(torch::kFloat32); - at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = - at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT].to(torch::kFloat16); - at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = - at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT].to(torch::kFloat32); - } - } + loader_->merge_loaded_weights(); + auto& at_weight_tensors = loader_->get_at_weight_tensors(); c10_npu::NPUCachingAllocator::emptyCache(); for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { atb_weight_tensors_[i] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[i]); } init_layer(); } -torch::Tensor DeepseekV2DecoderLayerImpl::convert_fp16_to_int64( - const torch::Tensor& fp16_tensor) { - auto float_tensor = fp16_tensor.to(torch::kFloat32); - auto int32_tensor = float_tensor.view(torch::kInt32); - auto int64_tensor = int32_tensor.to(torch::kInt64); - return int64_tensor; -} - -void DeepseekV2DecoderLayerImpl::convert_descaled_weights_to_float() { - auto convert_to_float = [this](int index) { - at_weight_tensors_[index] = at_weight_tensors_[index].to(torch::kFloat32); - }; - convert_to_float(IN_Q_PROJ_A_DESCALE); - convert_to_float(IN_Q_PROJ_B_DESCALE); - convert_to_float(IN_KV_PROJ_WITH_MQA_DESCALE); - convert_to_float(IN_ATTENTION_OUT_DESCALE); -} - -void DeepseekV2DecoderLayerImpl::convert_offsets_to_int8() { - auto convert_to_int8 = [this](int index) { - at_weight_tensors_[index] = - at_weight_tensors_[index].to(torch::kInt8).to(device_); - }; - convert_to_int8(IN_Q_PROJ_A_OFFSET); - convert_to_int8(IN_Q_PROJ_B_OFFSET); - convert_to_int8(IN_KV_PROJ_WITH_MQA_OFFSET); - convert_to_int8(IN_ATTENTION_OUT_OFFSET); -} - -void DeepseekV2DecoderLayerImpl::handle_device_specific_bias() { - if (dp_local_tp_rank_ != 0) { - torch::Tensor original_tensor = at_weight_tensors_[IN_ATTENTION_OUT_BIAS]; - at_weight_tensors_[IN_ATTENTION_OUT_BIAS] = - torch::zeros(original_tensor.sizes(), - torch::TensorOptions() - .dtype(original_tensor.dtype()) - .device(original_tensor.device())); - } -} - -void DeepseekV2DecoderLayerImpl::merge_shared_experts_weights() { - auto merge_and_clear = [this](int index, - torch::Tensor& shared_experts_gate, - torch::Tensor& shared_experts_up) { - at_weight_tensors_[index] = - torch::cat({shared_experts_gate, shared_experts_up}, 0) - .to(device_) - .contiguous(); - shared_experts_gate = tensor_placeholder_; - shared_experts_up = tensor_placeholder_; - }; - - if (layer_id_ >= prefill_param_.firstKDenseReplace) { - merge_and_clear( - IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, - shared_experts_weights_["mlp.shared_experts.gate_proj.weight"], - shared_experts_weights_["mlp.shared_experts.up_proj.weight"]); - if (quantize_type_ == "w8a8_dynamic") { - merge_and_clear( - IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, - 
shared_experts_weights_["mlp.shared_experts.gate_proj.weight_offset"], - shared_experts_weights_["mlp.shared_experts.up_proj.weight_offset"]); - merge_and_clear( - IN_MLP_GATEUP_SCALE_SHARED_EXPERT, - shared_experts_weights_["mlp.shared_experts.gate_proj.weight_scale"], - shared_experts_weights_["mlp.shared_experts.up_proj.weight_scale"]); - } - } else { - merge_and_clear(IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, - shared_experts_weights_["mlp.gate_proj.weight"], - shared_experts_weights_["mlp.up_proj.weight"]); - if (quantize_type_ == "w8a8_dynamic") { - merge_and_clear(IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, - shared_experts_weights_["mlp.gate_proj.weight_offset"], - shared_experts_weights_["mlp.up_proj.weight_offset"]); - merge_and_clear(IN_MLP_GATEUP_SCALE_SHARED_EXPERT, - shared_experts_weights_["mlp.gate_proj.weight_scale"], - shared_experts_weights_["mlp.up_proj.weight_scale"]); - } - } -} - -void DeepseekV2DecoderLayerImpl::merge_experts_weights() { - torch::Tensor mlp_gateup_weight = - merge_experts_weights(experts_weights_["gate_proj.weight"], - experts_weights_["up_proj.weight"], - device_, - /*transpose=*/true); - at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT] = - at_npu::native::npu_format_cast(mlp_gateup_weight, 29); - // at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT] = - // at_npu::native::npu_format_cast(mlp_gateup_weight, 2).contiguous(); - if (quantize_type_ == "w8a8_dynamic") { - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT] = - merge_experts_weights(experts_weights_["gate_proj.weight_offset"], - experts_weights_["up_proj.weight_offset"], - device_); - at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT] = - merge_experts_weights(experts_weights_["gate_proj.weight_scale"], - experts_weights_["up_proj.weight_scale"], - device_); - } - -#if defined(USE_A3) - torch::Tensor mlp_down_weight = - merge_experts_weights(experts_weights_["down_proj.weight"], - device_, - /*transpose=*/false); - // at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = - // at_npu::native::npu_format_cast(mlp_down_weight, 29); - at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = - at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); -#else - // TODO: xllm ops's GMM need to support MTP. 
- if (decode_param_.isBF16 && false) { - torch::Tensor mlp_down_weight = - merge_experts_weights(experts_weights_["down_proj.weight"], - device_, - /*transpose=*/true); - at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = - at_npu::native::npu_format_cast(mlp_down_weight, 29); - } else { - torch::Tensor mlp_down_weight = - merge_experts_weights(experts_weights_["down_proj.weight"], - device_, - /*transpose=*/false); - at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = - at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); - } -#endif - if (quantize_type_ == "w8a8_dynamic") { - at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = merge_experts_weights( - experts_weights_["down_proj.weight_offset"], device_); - at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = merge_experts_weights( - experts_weights_["down_proj.weight_scale"], device_); - } -} - -torch::Tensor DeepseekV2DecoderLayerImpl::merge_experts_weights( - std::vector& experts, - at::Device device, - bool transpose) { - torch::Tensor merged_tensor = torch::stack(experts, 0).to(device); - if (transpose) { - merged_tensor = merged_tensor.transpose(1, 2); - } - merged_tensor = merged_tensor.contiguous(); - experts.clear(); - return merged_tensor; -} +torch::Tensor DeepseekV2DecoderLayerImpl::build_expert_routing_map( + std::vector expert_lists) { + std::unordered_map> expert_routing_map; -torch::Tensor DeepseekV2DecoderLayerImpl::merge_experts_weights( - std::vector& experts_gate, - std::vector& experts_up, - at::Device device, - bool transpose) { - for (size_t i = 0; i < experts_up.size(); ++i) { - experts_gate[i] = torch::cat({experts_gate[i], experts_up[i]}, 0); + for (int64_t i = 0; i < expert_lists.size(); ++i) { + int64_t v = expert_lists[i]; + expert_routing_map[v].emplace_back(i); } - torch::Tensor merged_tensor = torch::stack(experts_gate, 0).to(device); + std::vector keys; + std::vector values; + for (auto& [key, indices] : expert_routing_map) { + int num_of_duplications = indices.size(); + int selected_index = ep_rank_ % num_of_duplications; + indices = {indices[selected_index]}; - if (transpose) { - merged_tensor = merged_tensor.transpose(1, 2); + keys.emplace_back(key); + values.emplace_back(static_cast(indices[0])); } - merged_tensor = merged_tensor.contiguous(); - experts_gate.clear(); - experts_up.clear(); - return merged_tensor; -} - -void DeepseekV2DecoderLayerImpl::merge_and_copy_gate_up_weights( - torch::Tensor& - target_buffer, // [num_experts, hidden_dim, gate_dim + up_dim] - const std::vector& experts_gate, // [gate_dim, hidden_dim] - const std::vector& experts_up, // [up_dim, hidden_dim] - bool do_transpose) { - const int64_t num_experts = experts_gate.size(); - const int64_t gate_dim = experts_gate[0].size(0); - const int64_t up_dim = experts_up[0].size(0); - const int64_t hidden_dim = experts_gate[0].size(1); - - target_buffer = at_npu::native::npu_format_cast(target_buffer.contiguous(), 2) - .reshape({num_experts, gate_dim + up_dim, hidden_dim}); - - for (int64_t index = 0; index < num_experts; ++index) { - target_buffer[index].slice(0, 0, gate_dim).copy_(experts_gate[index]); - - target_buffer[index] - .slice(0, gate_dim, gate_dim + up_dim) - .copy_(experts_up[index]); - } + int64_t map_size = expert_routing_map.size(); + auto options = torch::TensorOptions().dtype(torch::kInt32); + auto input = torch::zeros({map_size}, options); - if (do_transpose) { - target_buffer = target_buffer.transpose(1, 2).contiguous(); - ; - } + auto index_tensor = torch::tensor(keys, torch::kInt64); + auto value_tensor = 
torch::tensor(values, torch::kInt32); + auto result = input.scatter(0, index_tensor, value_tensor).to(device_); + // result = result.reshape({ep_size_,result.size(0)/ep_size_}).contiguous(); + return result; } -void DeepseekV2DecoderLayerImpl::merge_and_copy_down_weights( - torch::Tensor& target_buffer, - const std::vector& experts_down) { - const int64_t num_experts = experts_down.size(); - - for (int64_t index = 0; index < num_experts; ++index) { - target_buffer[index].copy_(experts_down[index]); - } +std::string DeepseekV2DecoderLayerImpl::get_expert_shm_key( + int32_t layer_id, + int32_t expert_index, + const std::string& suffix) { + std::string shm_key = + "layer_" + std::to_string(layer_id - first_k_dense_replace_) + "_" + + "expert_" + std::to_string(expert_index) + "_" + suffix; + return shm_key; } void DeepseekV2DecoderLayerImpl::prepare_expert_weight( const std::vector& expert_list) { + auto& at_weight_tensors = loader_->get_at_weight_tensors(); + auto& experts_weights = loader_->get_experts_weight_tensors(); expert_routing_map_buffer_ = build_expert_routing_map(expert_list); auto& expert_buffer = ExpertBuffer::Instance(); const int32_t num_local_experts = num_experts_per_partition_; const int32_t hidden_dim = - at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT].size(1); + at_weight_tensors[IN_MLP_GATEUP_WEIGHT_EXPERT].size(1); const int32_t combined_dim = - at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT].size(2); + at_weight_tensors[IN_MLP_GATEUP_WEIGHT_EXPERT].size(2); const int32_t gate_dim = combined_dim / 2; expert_buffer.initialize_or_reuse( @@ -1353,91 +572,100 @@ void DeepseekV2DecoderLayerImpl::prepare_expert_weight( combined_dim}, /*gateup_offset_shape*/ {num_experts_per_partition_, combined_dim, 1}, /*gateup_scale_shape*/ {num_experts_per_partition_, combined_dim, 1}, - /*down_weight_shape*/ {num_experts_per_partition_, hidden_dim, gate_dim}, + /*down_weight_shape*/ + {num_experts_per_partition_, hidden_dim, gate_dim}, /*down_offset_shape*/ {num_experts_per_partition_, hidden_dim, 1}, /*down_scale_shape*/ {num_experts_per_partition_, hidden_dim, 1}, - at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT].options(), - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT].options(), - at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT].options() + at_weight_tensors[IN_MLP_GATEUP_WEIGHT_EXPERT].options(), + at_weight_tensors[IN_MLP_GATEUP_OFFSET_EXPERT].options(), + at_weight_tensors[IN_MLP_GATEUP_SCALE_EXPERT].options() ); const int start_expert_idx = num_experts_per_partition_ * ep_rank_; const int end_expert_idx = start_expert_idx + num_experts_per_partition_ - 1; - for (const auto& pair : experts_weights_) { + auto& shared_buffer = loader_->get_expert_shared_buffer(); + for (const auto& pair : experts_weights) { for (int expert_idx = start_expert_idx; expert_idx <= end_expert_idx; ++expert_idx) { std::string shm_key = get_expert_shm_key(layer_id_, expert_list[expert_idx], pair.first); - experts_weights_[pair.first][expert_idx - start_expert_idx] = - shared_buffer_->get_tensor(expert_list[expert_idx], - layer_id_ - first_k_dense_replace_, - shm_key); + experts_weights[pair.first][expert_idx - start_expert_idx] = + shared_buffer->get_tensor(expert_list[expert_idx], + layer_id_ - first_k_dense_replace_, + shm_key); // experts_weights_[pair.first][expert_idx] = // shared_buffer_->get_tensors(shm_key); } } merge_and_copy_gate_up_weights(expert_buffer.gateup_weight, - experts_weights_["gate_proj.weight"], - experts_weights_["up_proj.weight"], + experts_weights["gate_proj.weight"], + 
experts_weights["up_proj.weight"], /*do_transpose=*/true); merge_and_copy_gate_up_weights(expert_buffer.gateup_offset, - experts_weights_["gate_proj.weight_offset"], - experts_weights_["up_proj.weight_offset"]); + experts_weights["gate_proj.weight_offset"], + experts_weights["up_proj.weight_offset"]); merge_and_copy_gate_up_weights(expert_buffer.gateup_scale, - experts_weights_["gate_proj.weight_scale"], - experts_weights_["up_proj.weight_scale"]); + experts_weights["gate_proj.weight_scale"], + experts_weights["up_proj.weight_scale"]); merge_and_copy_down_weights(expert_buffer.down_weight, - experts_weights_["down_proj.weight"]); + experts_weights["down_proj.weight"]); merge_and_copy_down_weights(expert_buffer.down_offset, - experts_weights_["down_proj.weight_offset"]); + experts_weights["down_proj.weight_offset"]); merge_and_copy_down_weights(expert_buffer.down_scale, - experts_weights_["down_proj.weight_scale"]); + experts_weights["down_proj.weight_scale"]); expert_buffer.gateup_weight = at_npu::native::npu_format_cast(expert_buffer.gateup_weight, 29); } -torch::Tensor DeepseekV2DecoderLayerImpl::build_expert_routing_map( - std::vector expert_lists) { - std::unordered_map> expert_routing_map; +void DeepseekV2DecoderLayerImpl::merge_and_copy_gate_up_weights( + torch::Tensor& + target_buffer, // [num_experts, hidden_dim, gate_dim + up_dim] + const std::vector& experts_gate, // [gate_dim, hidden_dim] + const std::vector& experts_up, // [up_dim, hidden_dim] + bool do_transpose) { + const int64_t num_experts = experts_gate.size(); + const int64_t gate_dim = experts_gate[0].size(0); + const int64_t up_dim = experts_up[0].size(0); + const int64_t hidden_dim = experts_gate[0].size(1); - for (int64_t i = 0; i < expert_lists.size(); ++i) { - int64_t v = expert_lists[i]; - expert_routing_map[v].emplace_back(i); - } + target_buffer = at_npu::native::npu_format_cast(target_buffer.contiguous(), 2) + .reshape({num_experts, gate_dim + up_dim, hidden_dim}); - std::vector keys; - std::vector values; - for (auto& [key, indices] : expert_routing_map) { - int num_of_duplications = indices.size(); - int selected_index = ep_rank_ % num_of_duplications; - indices = {indices[selected_index]}; + for (int64_t index = 0; index < num_experts; ++index) { + target_buffer[index].slice(0, 0, gate_dim).copy_(experts_gate[index]); - keys.emplace_back(key); - values.emplace_back(static_cast(indices[0])); + target_buffer[index] + .slice(0, gate_dim, gate_dim + up_dim) + .copy_(experts_up[index]); } - int64_t map_size = expert_routing_map.size(); - auto options = torch::TensorOptions().dtype(torch::kInt32); - auto input = torch::zeros({map_size}, options); + if (do_transpose) { + target_buffer = target_buffer.transpose(1, 2).contiguous(); + } +} - auto index_tensor = torch::tensor(keys, torch::kInt64); - auto value_tensor = torch::tensor(values, torch::kInt32); - auto result = input.scatter(0, index_tensor, value_tensor).to(device_); - // result = result.reshape({ep_size_,result.size(0)/ep_size_}).contiguous(); - return result; +void DeepseekV2DecoderLayerImpl::merge_and_copy_down_weights( + torch::Tensor& target_buffer, + const std::vector& experts_down) { + const int64_t num_experts = experts_down.size(); + + for (int64_t index = 0; index < num_experts; ++index) { + target_buffer[index].copy_(experts_down[index]); + } } void DeepseekV2DecoderLayerImpl::update_expert_weight() { auto& expert_buffer = ExpertBuffer::Instance(); + auto& at_weight_tensors = loader_->get_at_weight_tensors(); const auto tensor_pairs = { 
std::make_pair(IN_MLP_GATEUP_WEIGHT_EXPERT, std::ref(expert_buffer.gateup_weight)), @@ -1452,9 +680,9 @@ void DeepseekV2DecoderLayerImpl::update_expert_weight() { std::make_pair(IN_MLP_DOWN_SCALE_EXPERT, std::ref(expert_buffer.down_scale))}; for (auto& [index, buffer_tensor] : tensor_pairs) { - std::swap(at_weight_tensors_[index], buffer_tensor); + std::swap(at_weight_tensors[index], buffer_tensor); atb_weight_tensors_[index] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[index]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[index]); prefill_node_.inTensors.at(index) = &atb_weight_tensors_[index]; decode_node_.inTensors.at(index) = &atb_weight_tensors_[index]; decode_mla_node_.inTensors.at(index) = &atb_weight_tensors_[index]; @@ -1464,14 +692,6 @@ void DeepseekV2DecoderLayerImpl::update_expert_weight() { expert_routing_map_ = expert_routing_map_.contiguous(); } -void DeepseekV2DecoderLayerImpl::squeeze_experts_weights() { - for (const auto& index : SQUEEZE_WEIGHT_VEC) { - if (at_weight_tensors_[index].dim() > 1) { - at_weight_tensors_[index] = at_weight_tensors_[index].squeeze(); - } - } -} - int64_t DeepseekV2DecoderLayerImpl::init_layer() { name_ = "deepseek_v2_decoder_layer " + std::to_string(layer_id_); model_name_ = "DeepSeek_V2"; diff --git a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h index c3599878c..012ef9c00 100644 --- a/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_deepseek_v2_decoder_layer_impl.h @@ -27,6 +27,7 @@ limitations under the License. #include "framework/model/npu_dp_ep_padding.h" #include "framework/model_context.h" #include "framework/state_dict/state_dict.h" +#include "loader/deepseek_v2_decoder_loader.h" #include "npu_base_layer.h" #include "xllm_kernels/models/deepseekv2/layer/decoder_layer.h" @@ -110,12 +111,10 @@ class DeepseekV2DecoderLayerImpl : public BaseLayer { ~DeepseekV2DecoderLayerImpl() {}; - virtual void load_state_dict(const StateDict& state_dict) override; - - void verify_loaded_weights(const std::string& prefix) const; - virtual void merge_loaded_weights() override; + torch::Tensor build_expert_routing_map(std::vector expert_lists); + void prepare_expert_weight(const std::vector& expert_list); void update_expert_weight(); @@ -139,9 +138,11 @@ class DeepseekV2DecoderLayerImpl : public BaseLayer { bool use_dp_sharding = false; }; - void initialize_tensors(const torch::TensorOptions& options); + std::string get_expert_shm_key(int32_t layer_id, + int32_t expert_index, + const std::string& suffix); - void initialize_weight_tensors(const torch::TensorOptions& options); + void initialize_tensors(const torch::TensorOptions& options); void param_from_args(atb_speed::deepseekV2::DecoderLayerParam& param, const ModelArgs& args, @@ -149,7 +150,9 @@ class DeepseekV2DecoderLayerImpl : public BaseLayer { bool is_prefill); void reserve_experts_weights(int num_of_device_experts); + void initialize_device_expert_list(int numdevice, int num_layers); + void initialize_basic_parameters( atb_speed::deepseekV2::DecoderLayerParam& param, const ModelArgs& args, @@ -178,79 +181,6 @@ class DeepseekV2DecoderLayerImpl : public BaseLayer { const ModelArgs& args, bool is_prefill); - torch::Tensor get_sharded_tensor(const StateDict& state_dict, - const std::string& name, - int dim); - - torch::Tensor get_sharded_tensor(const StateDict& state_dict, - const std::string& name, - int dim, - int local_tp_rank, - int local_tp_size); - - std::string 
extract_endswith(const std::string& input); - - std::string get_expert_shm_key(int32_t layer_id, - int32_t expert_ids, - const std::string& suffix); - torch::Tensor build_expert_routing_map(std::vector expert_lists); - void set_kv_weight(const StateDict& state_dict, - const std::string& tensor_name, - int weight_position, - int dim); - - int extract_expert_index(const std::string& name); - - void convert_descaled_weights_to_float(); - - void convert_offsets_to_int8(); - - torch::Tensor convert_fp16_to_int64(const torch::Tensor& fp16_tensor); - - void handle_device_specific_bias(); - - void merge_shared_experts_weights(); - - void merge_experts_weights(); - - void squeeze_experts_weights(); - - void preprocess_linear_for_rope(); - - void process_expert_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor); - - void process_shared_expert_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor); - - void process_mlp_common_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor); - - void process_general_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor); - - int get_mapped_index(const std::string& name, - const std::unordered_map& mapping); - - torch::Tensor view_tensor(torch::Tensor weight, - const std::string& name, - bool pre_view); - - torch::Tensor trans_rope_weight(torch::Tensor weight); - - torch::Tensor merge_experts_weights(std::vector& experts, - at::Device device, - bool transpose = false); - - torch::Tensor merge_experts_weights(std::vector& experts_up, - std::vector& experts_gate, - at::Device device, - bool transpose = false); - void merge_and_copy_gate_up_weights( torch::Tensor& target_buffer, const std::vector& experts_gate, @@ -328,17 +258,7 @@ class DeepseekV2DecoderLayerImpl : public BaseLayer { torch::Tensor at_in_device_expert_count_; std::vector int_placeholder_; - std::vector device_expert_list_; - - std::unordered_map shared_experts_weights_; - std::unordered_map> experts_weights_; - std::unordered_map> - all_experts_weights_buffer_; - - std::mutex shared_experts_mutex_; - std::mutex experts_mutex_; - std::unique_ptr shared_buffer_ = nullptr; torch::Tensor expert_routing_map_; torch::Tensor expert_routing_map_buffer_; }; diff --git a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp index ef07e0958..edab794c3 100644 --- a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp +++ b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp @@ -24,253 +24,8 @@ DECLARE_int32(expert_parallel_degree); namespace xllm { namespace layer { -enum DecoderLayerTensorId : int { - IN_INPUT_NORM_WEIGHT = 0, - IN_INPUT_NORM_BIAS = 1, - IN_INPUT_NORM_NEW_WEIGHT = 2, - IN_INPUT_NORM_NEW_BIAS = 3, - - IN_QKV_WEIGHT_0 = 4, - IN_QKV_BIAS_0 = 5, - IN_QKV_DESCALE_0 = 6, - IN_QKV_OFFSET_0 = 7, - IN_QKV_SCALE_0 = 8, - IN_QKV_COMPRESS_IDX_0 = 9, - - IN_QKV_WEIGHT_1 = 10, - IN_QKV_BIAS_1 = 11, - IN_QKV_DESCALE_1 = 12, - IN_QKV_OFFSET_1 = 13, - IN_QKV_SCALE_1 = 14, - IN_QKV_COMPRESS_IDX_1 = 15, - - IN_QKV_WEIGHT_2 = 16, - IN_QKV_BIAS_2 = 17, - IN_QKV_DESCALE_2 = 18, - IN_QKV_OFFSET_2 = 19, - IN_QKV_SCALE_2 = 20, - IN_QKV_COMPRESS_IDX_2 = 21, - - IN_QKV_DENSE_WEIGHT = 22, - IN_QKV_DENSE_BIAS = 23, - IN_QKV_DENSE_DESCALE = 24, - IN_QKV_DENSE_OFFSET = 25, - IN_QKV_DENSE_SCALE = 26, - IN_QKV_DENSE_COMPRESS_IDX = 27, - - IN_POST_ATTN_NORM_WEIGHT = 28, - IN_POST_ATTN_NORM_BIAS = 29, - 
IN_POST_ATTN_NORM_NEW_WEIGHT = 30, - IN_POST_ATTN_NORM_NEW_BIAS = 31, - - IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT = 32, - IN_MLP_GATEUP_BIAS_SHARED_EXPERT = 33, - IN_MLP_GATEUP_DESCALE_SHARED_EXPERT = 34, - IN_MLP_GATEUP_OFFSET_SHARED_EXPERT = 35, - IN_MLP_GATEUP_SCALE_SHARED_EXPERT = 36, - IN_MLP_GATEUP_COMPRESS_IDX_SHARED_EXPERT = 37, - - IN_MLP_DOWN_WEIGHT_SHARED_EXPERT = 38, - IN_MLP_DOWN_BIAS_SHARED_EXPERT = 39, - IN_MLP_DOWN_DESCALE_SHARED_EXPERT = 40, - IN_MLP_DOWN_OFFSET_SHARED_EXPERT = 41, - IN_MLP_DOWN_SCALE_SHARED_EXPERT = 42, - IN_MLP_DOWN_COMPRESS_IDX_SHARED_EXPERT = 43, - - IN_SHARED_EXPERT_GATE_WEIGHT = 44, - IN_SHARED_EXPERT_GATE_BIAS = 45, - IN_SHARED_EXPERT_GATE_DESCALE = 46, - IN_SHARED_EXPERT_GATE_OFFSET = 47, - IN_SHARED_EXPERT_GATE_SCALE = 48, - IN_SHARED_EXPERT_GATE_COMPRESS_IDX = 49, - - BLOCK_SPARSE_MOE_GATE_WEIGHT = 50, - BLOCK_SPARSE_MOE_GATE_BIAS = 51, - BLOCK_SPARSE_MOE_GATE_DESCALE = 52, - BLOCK_SPARSE_MOE_GATE_OFFSET = 53, - BLOCK_SPARSE_MOE_GATE_SCALE = 54, - BLOCK_SPARSE_MOE_GATE_COMPRESS_IDX = 55, - - IN_MLP_GATEUP_WEIGHT = 56, - IN_MLP_GATEUP_BIAS = 57, - IN_MLP_GATEUP_DESCALE = 58, - IN_MLP_GATEUP_OFFSET = 59, - IN_MLP_GATEUP_SCALE = 60, - IN_MLP_GATEUP_COMPRESS_IDX = 61, - - IN_MLP_DOWN_WEIGHT = 62, - IN_MLP_DOWN_BIAS = 63, - IN_MLP_DOWN_DESCALE = 64, - IN_MLP_DOWN_OFFSET = 65, - IN_MLP_DOWN_SCALE = 66, - IN_MLP_DOWN_COMPRESS_IDX = 67, - - Q_NORM_WEIGHT = 68, - K_NORM_WEIGHT = 69 -}; - static uint64_t WEIGHT_COUNT_PER_LAYER = 68; -static std::unordered_map WEIGHT_MAPPING = { - {"input_layernorm.weight", IN_INPUT_NORM_WEIGHT}, - - {"self_attn.q_proj.weight", IN_QKV_WEIGHT_0}, - {"self_attn.q_proj.bias", IN_QKV_BIAS_0}, - - {"self_attn.k_proj.weight", IN_QKV_WEIGHT_1}, - {"self_attn.k_proj.bias", IN_QKV_BIAS_1}, - - {"self_attn.v_proj.weight", IN_QKV_WEIGHT_2}, - {"self_attn.v_proj.bias", IN_QKV_BIAS_2}, - - {"self_attn.o_proj.weight", IN_QKV_DENSE_WEIGHT}, - - {"post_attention_layernorm.weight", IN_POST_ATTN_NORM_WEIGHT}, - - // mlp or shared expert - {"mlp.gate_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, - - {"mlp.up_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, - - {"mlp.down_proj.weight", IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, - - {"mlp.shared_experts.gate_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, - - {"mlp.shared_experts.up_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, - - {"mlp.shared_experts.down_proj.weight", IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, - - // MoE Gate - {"mlp.gate.weight", BLOCK_SPARSE_MOE_GATE_WEIGHT}, - {"mlp.gate.e_score_correction_bias", BLOCK_SPARSE_MOE_GATE_BIAS}, - - // Expert MLP - Gate/Up projections - {"gate_proj.weight", IN_MLP_GATEUP_WEIGHT}, - {"up_proj.weight", IN_MLP_GATEUP_WEIGHT}, - - // Expert MLP - Down projection - {"down_proj.weight", IN_MLP_DOWN_WEIGHT}, - -}; - -static std::unordered_map WEIGHT_MAPPING_W8A8 = { - {"input_layernorm.weight", IN_INPUT_NORM_WEIGHT}, - {"input_layernorm.bias", IN_INPUT_NORM_NEW_BIAS}, - - {"self_attn.q_proj.weight", IN_QKV_WEIGHT_0}, - {"self_attn.q_proj.deq_scale", IN_QKV_DESCALE_0}, - {"self_attn.q_proj.quant_bias", IN_QKV_BIAS_0}, - {"self_attn.q_proj.input_offset", IN_QKV_OFFSET_0}, - {"self_attn.q_proj.input_scale", IN_QKV_SCALE_0}, - - {"self_attn.k_proj.weight", IN_QKV_WEIGHT_1}, - {"self_attn.k_proj.deq_scale", IN_QKV_DESCALE_1}, - {"self_attn.k_proj.quant_bias", IN_QKV_BIAS_1}, - - {"self_attn.v_proj.weight", IN_QKV_WEIGHT_2}, - {"self_attn.v_proj.deq_scale", IN_QKV_DESCALE_2}, - {"self_attn.v_proj.quant_bias", IN_QKV_BIAS_2}, - - {"self_attn.o_proj.weight", 
IN_QKV_DENSE_WEIGHT}, - {"self_attn.o_proj.quant_bias", IN_QKV_DENSE_BIAS}, - {"self_attn.o_proj.deq_scale", IN_QKV_DENSE_DESCALE}, - {"self_attn.o_proj.weight_offset", IN_QKV_DENSE_OFFSET}, - {"self_attn.o_proj.weight_scale", IN_QKV_DENSE_SCALE}, - - {"post_attention_layernorm.weight", IN_POST_ATTN_NORM_WEIGHT}, - {"post_attention_layernorm.bias", IN_POST_ATTN_NORM_NEW_BIAS}, - - // mlp - {"mlp.gate_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, - {"mlp.gate_proj.weight_offset", IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, - {"mlp.gate_proj.weight_scale", IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, - - {"mlp.up_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, - {"mlp.up_proj.weight_offset", IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, - {"mlp.up_proj.weight_scale", IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, - - {"mlp.down_proj.weight", IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, - {"mlp.down_proj.weight_offset", IN_MLP_DOWN_OFFSET_SHARED_EXPERT}, - {"mlp.down_proj.weight_scale", IN_MLP_DOWN_SCALE_SHARED_EXPERT}, - - // shared expert - {"mlp.shared_experts.gate_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, - {"mlp.shared_experts.gate_proj.weight_offset", - IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, - {"mlp.shared_experts.gate_proj.weight_scale", - IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, - - {"mlp.shared_experts.up_proj.weight", IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT}, - {"mlp.shared_experts.up_proj.weight_offset", - IN_MLP_GATEUP_OFFSET_SHARED_EXPERT}, - {"mlp.shared_experts.up_proj.weight_scale", - IN_MLP_GATEUP_SCALE_SHARED_EXPERT}, - - {"mlp.shared_experts.down_proj.weight", IN_MLP_DOWN_WEIGHT_SHARED_EXPERT}, - {"mlp.shared_experts.down_proj.weight_offset", - IN_MLP_DOWN_OFFSET_SHARED_EXPERT}, - {"mlp.shared_experts.down_proj.weight_scale", - IN_MLP_DOWN_SCALE_SHARED_EXPERT}, - - // MoE Gate - {"mlp.gate.weight", BLOCK_SPARSE_MOE_GATE_WEIGHT}, - {"mlp.gate.e_score_correction_bias", BLOCK_SPARSE_MOE_GATE_BIAS}, - - {"gate_proj.weight", IN_MLP_GATEUP_WEIGHT}, - {"gate_proj.weight_offset", IN_MLP_GATEUP_OFFSET}, - {"gate_proj.weight_scale", IN_MLP_GATEUP_SCALE}, - {"up_proj.weight", IN_MLP_GATEUP_WEIGHT}, - {"up_proj.weight_offset", IN_MLP_GATEUP_OFFSET}, - {"up_proj.weight_scale", IN_MLP_GATEUP_SCALE}, - - {"down_proj.weight", IN_MLP_DOWN_WEIGHT}, - {"down_proj.weight_offset", IN_MLP_DOWN_OFFSET}, - {"down_proj.weight_scale", IN_MLP_DOWN_SCALE}, -}; - -static const std::unordered_map> - SPECIAL_MULTI_ASSIGN_W8A8 = { - {"input_layernorm.weight", - {IN_INPUT_NORM_WEIGHT, IN_INPUT_NORM_NEW_WEIGHT}}, - {"post_attention_layernorm.weight", - {IN_POST_ATTN_NORM_WEIGHT, IN_POST_ATTN_NORM_NEW_WEIGHT}}, -}; - -static const std::map WEIGHT_SHARD = { - {IN_QKV_WEIGHT_0, 0}, - {IN_QKV_BIAS_0, 0}, - {IN_QKV_WEIGHT_1, 0}, - {IN_QKV_BIAS_1, 0}, - {IN_QKV_WEIGHT_2, 0}, - {IN_QKV_BIAS_2, 0}, - {IN_QKV_DENSE_WEIGHT, 1}, - {IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, 0}, - {IN_MLP_DOWN_WEIGHT_SHARED_EXPERT, 1}, - {IN_MLP_GATEUP_WEIGHT, 0}, - {IN_MLP_DOWN_WEIGHT, 1}, -}; - -static const std::map WEIGHT_SHARD_W8A8 = { - {IN_QKV_WEIGHT_0, 0}, - {IN_QKV_BIAS_0, 0}, - {IN_QKV_DESCALE_0, 0}, - {IN_QKV_WEIGHT_1, 0}, - {IN_QKV_BIAS_1, 0}, - {IN_QKV_DESCALE_1, 0}, - {IN_QKV_WEIGHT_2, 0}, - {IN_QKV_BIAS_2, 0}, - {IN_QKV_DESCALE_2, 0}, - {IN_QKV_DENSE_WEIGHT, 1}, - {IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, 0}, - {IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, 0}, - {IN_MLP_GATEUP_SCALE_SHARED_EXPERT, 0}, - {IN_MLP_DOWN_WEIGHT_SHARED_EXPERT, 1}, - {IN_MLP_GATEUP_WEIGHT, 0}, - {IN_MLP_GATEUP_OFFSET, 0}, - {IN_MLP_GATEUP_SCALE, 0}, - {IN_MLP_DOWN_WEIGHT, 1}, -}; - 
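The tensor-id enum and the static WEIGHT_MAPPING / WEIGHT_MAPPING_W8A8 / WEIGHT_SHARD / WEIGHT_SHARD_W8A8 tables deleted above presumably move into the Glm4MoeDecoderLoader that the constructor below creates. That loader is not part of this hunk, so the following is only a minimal sketch of how such a subclass can sit on top of the BaseLoader helpers introduced in this patch; the class name is hypothetical and the tables are abbreviated to three entries, with slot ids taken from the removed DecoderLayerTensorId enum:

// Sketch only: the kind of subclass that can own the name -> weight-slot
// mapping and the tensor-parallel shard dimension, instead of the enclosing
// layer. Class name and the abbreviated tables are illustrative.
#include <map>
#include <string>
#include <unordered_map>

#include "base_loader.h"

namespace xllm {
namespace layer {

class ExampleMoeDecoderLoader : public BaseLoader {  // hypothetical name
 public:
  ExampleMoeDecoderLoader(uint64_t weight_count, const ModelContext& context)
      : BaseLoader(weight_count, context) {}

  void load_state_dict(const StateDict& state_dict) override {
    for (const auto& [name, slot] : kSlotOf) {
      auto shard = kShardDim.find(slot);
      if (shard == kShardDim.end()) {
        // Replicated weight: copied to every rank as-is.
        set_weight(state_dict, name, slot);
      } else {
        // Tensor-parallel weight: BaseLoader shards it along shard->second.
        set_weight(state_dict, name, slot, shard->second);
      }
    }
  }

 private:
  // Abbreviated stand-ins for the removed WEIGHT_MAPPING / WEIGHT_SHARD tables.
  inline static const std::unordered_map<std::string, int> kSlotOf = {
      {"input_layernorm.weight", 0},    // IN_INPUT_NORM_WEIGHT
      {"self_attn.q_proj.weight", 4},   // IN_QKV_WEIGHT_0
      {"self_attn.o_proj.weight", 22},  // IN_QKV_DENSE_WEIGHT
  };
  inline static const std::map<int, int> kShardDim = {
      {4, 0},   // q_proj is split along dim 0 (output features)
      {22, 1},  // o_proj is split along dim 1 (input features)
  };
};

}  // namespace layer
}  // namespace xllm

The layer itself then keeps only the ATB-side bookkeeping: it pulls the loaded tensors back out via loader_->get_at_weight_tensors() when it needs to wrap them into atb::Tensor handles, as the DeepseekV2DecoderLayerImpl::merge_loaded_weights() change earlier in this patch shows.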
Glm4MoeDecoderImpl::Glm4MoeDecoderImpl(const ModelContext& context, const int32_t layer_id) : BaseLayer(context), @@ -282,7 +37,6 @@ Glm4MoeDecoderImpl::Glm4MoeDecoderImpl(const ModelContext& context, auto parallel_args = context.get_parallel_args(); auto options = context.get_tensor_options(); - num_experts_ = model_args.num_experts(); ep_size_ = parallel_args.ep_size(); ep_local_tp_size_ = parallel_args.world_size() / ep_size_; CHECK_EQ(parallel_args.world_size(), ep_size_ * ep_local_tp_size_); @@ -292,15 +46,17 @@ Glm4MoeDecoderImpl::Glm4MoeDecoderImpl(const ModelContext& context, start_expert_id_ = ep_rank_ * num_experts_per_partition_; end_expert_id_ = start_expert_id_ + num_experts_per_partition_ - 1; - dp_size_ = parallel_args.dp_size(); - dp_local_tp_size_ = parallel_args.world_size() / dp_size_; - CHECK_EQ(parallel_args.world_size(), dp_size_ * dp_local_tp_size_); - dp_local_tp_rank_ = parallel_args.rank() % dp_local_tp_size_; - - n_kv_heads_ = static_cast(model_args.n_kv_heads().value()); - param_from_args(prefill_param_, model_args, parallel_args, true); param_from_args(decode_param_, model_args, parallel_args, false); + atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); + placeholder_vec_ = {1}; + device_id_ = options.device().index(); + + loader_ = + std::make_unique(WEIGHT_COUNT_PER_LAYER, + context, + layer_id_, + prefill_param_.firstKDenseReplace); initialize_tensors(options); } @@ -308,7 +64,7 @@ Glm4MoeDecoderImpl::Glm4MoeDecoderImpl(const ModelContext& context, void Glm4MoeDecoderImpl::initialize_tensors( const torch::TensorOptions& options) { // initializ placeholder - at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); + atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); placeholder_vec_ = {1}; int_tensor_placeholder_ = torch::ones({1}).to(torch::kInt32).to(device_); @@ -316,7 +72,7 @@ void Glm4MoeDecoderImpl::initialize_tensors( block_tables_placeholder_ = torch::zeros({1, 1}).to(torch::kInt32).to(device_); tensor_placeholder_ = torch::zeros({1}).to(options); - resize_experts_weights(num_experts_per_partition_); + loader_->resize_experts_weights(num_experts_per_partition_); expert_group_ = torch::arange(1024, torch::kInt32).to(device_); one_hot_ = torch::tensor({1}, torch::kInt32).to(device_); zero_hot_ = torch::tensor({0}, torch::kInt32).to(device_); @@ -339,33 +95,11 @@ void Glm4MoeDecoderImpl::param_from_args(atb_speed::moe::MoeLayerParam& param, initialize_quantization_parameters(param); } -void Glm4MoeDecoderImpl::resize_experts_weights(int num_of_device_experts) { - experts_weights_["gate_proj.weight"] = - std::vector(num_of_device_experts); - experts_weights_["up_proj.weight"] = - std::vector(num_of_device_experts); - experts_weights_["down_proj.weight"] = - std::vector(num_of_device_experts); - if (quantize_type_.compare("w8a8_dynamic") == 0) { - experts_weights_["gate_proj.weight_offset"] = - std::vector(num_of_device_experts); - experts_weights_["up_proj.weight_offset"] = - std::vector(num_of_device_experts); - experts_weights_["down_proj.weight_offset"] = - std::vector(num_of_device_experts); - experts_weights_["gate_proj.weight_scale"] = - std::vector(num_of_device_experts); - experts_weights_["up_proj.weight_scale"] = - std::vector(num_of_device_experts); - experts_weights_["down_proj.weight_scale"] = - std::vector(num_of_device_experts); - } -} - void Glm4MoeDecoderImpl::initialize_weight_tensors( const torch::TensorOptions& options) { + auto& at_weight_tensors = loader_->get_at_weight_tensors(); for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { - 
at_weight_tensors_[i] = torch::zeros({1}).to(options); + at_weight_tensors[i] = torch::zeros({1}).to(options); } } @@ -413,10 +147,6 @@ void Glm4MoeDecoderImpl::initialize_basic_parameters( param.useQKNorm = args.use_qk_norm(); if (args.use_qk_norm()) { WEIGHT_COUNT_PER_LAYER = 70; - WEIGHT_MAPPING_W8A8["self_attn.q_norm.weight"] = Q_NORM_WEIGHT; - WEIGHT_MAPPING_W8A8["self_attn.k_norm.weight"] = K_NORM_WEIGHT; - WEIGHT_MAPPING["self_attn.q_norm.weight"] = Q_NORM_WEIGHT; - WEIGHT_MAPPING["self_attn.k_norm.weight"] = K_NORM_WEIGHT; } param.hiddenSizePerAttentionHead = args.head_dim(); std::optional optionalValue = args.n_kv_heads(); @@ -554,490 +284,17 @@ void Glm4MoeDecoderImpl::initialize_quantization_parameters( } } -void Glm4MoeDecoderImpl::load_state_dict(const StateDict& state_dict) { - for (const auto& [name, tensor] : state_dict) { - bool is_sharded = false; - int index = 0; - - if (absl::StartsWith(name, "mlp.experts")) { - process_expert_weights(state_dict, name, tensor); - continue; - } - if (absl::StartsWith(name, "mlp.shared_experts")) { - process_shared_expert_weights(state_dict, name, tensor); - continue; - } - if (absl::StartsWith(name, "mlp") && !absl::StrContains(name, "gate.")) { - process_mlp_common_weights(state_dict, name, tensor); - continue; - } - - process_general_weights(state_dict, name, tensor); - } -} - -int Glm4MoeDecoderImpl::get_mapped_index( - const std::string& name, - const std::unordered_map& mapping) { - const auto it = mapping.find(name); - if (it == mapping.end()) { - LOG(ERROR) << "Missing mapping for: " << name; - return -1; - } - - return it->second; -} - -void Glm4MoeDecoderImpl::process_expert_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor) { - int expert_index = extract_expert_index(name); - if (expert_index < start_expert_id_ || expert_index > end_expert_id_) { - return; - } - - const std::string suffix = extract_endswith(name); - const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_MAPPING_W8A8 - : WEIGHT_MAPPING; - const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_SHARD_W8A8 - : WEIGHT_SHARD; - const int index = get_mapped_index(suffix, weight_mapping); - const int local_index = expert_index % num_experts_per_partition_; - const bool is_sharded = shard_map.count(index); - - std::lock_guard lock(experts_mutex_); - torch::Tensor tmp_tensor = is_sharded - ? get_sharded_tensor(state_dict, - name, - shard_map.at(index), - ep_local_tp_rank_, - ep_local_tp_size_) - : tensor; - - experts_weights_[suffix][local_index] = tmp_tensor.clone(); -} - -void Glm4MoeDecoderImpl::process_shared_expert_weights( - const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor) { - torch::Tensor tmp_tensor; - const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_MAPPING_W8A8 - : WEIGHT_MAPPING; - const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_SHARD_W8A8 - : WEIGHT_SHARD; - std::lock_guard lock(shared_experts_mutex_); - const int index = get_mapped_index(name, weight_mapping); - if (index == -1) { - return; - } - - const bool is_sharded = shard_map.count(index); - tmp_tensor = is_sharded - ? 
get_sharded_tensor(state_dict, name, shard_map.at(index)) - .to(device_) - : tensor.to(device_); - - if (absl::StrContains(name, "down_proj")) { - at_weight_tensors_[index] = tmp_tensor; - } else { - shared_experts_weights_[name] = tmp_tensor; - } -} - -void Glm4MoeDecoderImpl::process_mlp_common_weights( - const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor) { - const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_MAPPING_W8A8 - : WEIGHT_MAPPING; - const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_SHARD_W8A8 - : WEIGHT_SHARD; - const int index = get_mapped_index(name, weight_mapping); - const bool is_sharded = shard_map.count(index); - - std::lock_guard lock(shared_experts_mutex_); - - torch::Tensor tmp_tensor = is_sharded - ? get_sharded_tensor(state_dict, - name, - shard_map.at(index), - dp_local_tp_rank_, - dp_local_tp_size_) - .to(device_) - : tensor.to(device_); - if (absl::StrContains(name, "down_proj")) { - at_weight_tensors_[index] = tmp_tensor; - } else { - shared_experts_weights_[name] = tmp_tensor; - } -} - -void Glm4MoeDecoderImpl::process_general_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor) { - const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_MAPPING_W8A8 - : WEIGHT_MAPPING; - const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_SHARD_W8A8 - : WEIGHT_SHARD; - - if (weight_mapping.find(name) == weight_mapping.end()) { - return; - } - - const int index = get_mapped_index(name, weight_mapping); - const bool is_sharded = shard_map.count(index); - torch::Tensor tmp_tensor; - int32_t tp_rank = dp_local_tp_rank_; - int32_t tp_size = dp_local_tp_size_; - if (index == IN_QKV_WEIGHT_1 || index == IN_QKV_WEIGHT_2 || - index == IN_QKV_BIAS_1 || index == IN_QKV_BIAS_2 || - index == IN_QKV_DESCALE_1 || index == IN_QKV_DESCALE_2) { - if (n_kv_heads_ < dp_local_tp_size_) { - int32_t repeat_times = (dp_local_tp_size_ / n_kv_heads_); - tp_rank = tp_rank / repeat_times; - tp_size = n_kv_heads_; - } - } - if (is_sharded) { - tmp_tensor = get_sharded_tensor( - state_dict, name, shard_map.at(index), tp_rank, tp_size) - .to(device_); - } else { - tmp_tensor = tensor.to(device_); - } - if (index == BLOCK_SPARSE_MOE_GATE_BIAS) { - auto min_val = tmp_tensor.min(); - tmp_tensor = tmp_tensor - min_val; - } - correct_tensor_dtype(tmp_tensor, name); - if (quantize_type_.compare("w8a8_dynamic") == 0) { - auto it = SPECIAL_MULTI_ASSIGN_W8A8.find(name); - if (it != SPECIAL_MULTI_ASSIGN_W8A8.end()) { - for (int idx : it->second) { - at_weight_tensors_[idx] = tmp_tensor; - } - return; - } - } - at_weight_tensors_[index] = tmp_tensor; -} - -torch::Tensor Glm4MoeDecoderImpl::get_sharded_tensor( - const StateDict& state_dict, - const std::string& name, - int dim) { - if (parallel_args_.world_size() > 1) { - return state_dict.get_sharded_tensor( - name, dim, parallel_args_.rank(), parallel_args_.world_size()); - } else { - return state_dict.get_tensor(name); - } -} - -torch::Tensor Glm4MoeDecoderImpl::get_sharded_tensor( - const StateDict& state_dict, - const std::string& name, - int dim, - int loacal_tp_rank, - int local_tp_size) { - if (local_tp_size > 1) { - return state_dict.get_sharded_tensor( - name, dim, loacal_tp_rank, local_tp_size); - } else { - return state_dict.get_tensor(name); - } -} - -std::string Glm4MoeDecoderImpl::extract_endswith(const std::string& input) { - std::vector parts; - 
std::stringstream ss(input); - std::string part; - while (std::getline(ss, part, '.')) { - parts.push_back(part); - } - if (parts.size() < 2) { - return ""; - } - std::string result = parts[parts.size() - 2] + "." + parts[parts.size() - 1]; - - return result; -} - -int Glm4MoeDecoderImpl::extract_expert_index(const std::string& name) { - std::string prefix = "experts."; - size_t pos = name.find(prefix); - if (pos != std::string::npos) { - pos += prefix.length(); - size_t end_pos = pos; - while (end_pos < name.length() && std::isdigit(name[end_pos])) { - ++end_pos; - } - if (end_pos > pos) { - return std::stoi(name.substr(pos, end_pos - pos)); - } - } - - return -1; -} - -void Glm4MoeDecoderImpl::verify_loaded_weights( - const std::string& prefix) const { - for (const auto& [name, index] : WEIGHT_MAPPING) { - if (name == "down_proj.weight" || name == "gate_proj.weight" || - name == "up_proj.weight" || name == "mlp.gate.weight" || - name == "mlp.gate.e_score_correction_bias") { - continue; - } - CHECK(at_weight_tensors_[index].sizes() != std::vector({0})) - << layer_id_ << "-weight is not loaded for " << name; - } -} - void Glm4MoeDecoderImpl::merge_loaded_weights() { - merge_shared_experts_weights(); - if (layer_id_ >= prefill_param_.firstKDenseReplace) { - merge_experts_weights(); - } - at_weight_tensors_[IN_QKV_WEIGHT_0] = - torch::cat({at_weight_tensors_[IN_QKV_WEIGHT_0], - at_weight_tensors_[IN_QKV_WEIGHT_1], - at_weight_tensors_[IN_QKV_WEIGHT_2]}, - 0) - .contiguous(); - at_weight_tensors_[IN_QKV_WEIGHT_1] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_WEIGHT_2] = - torch::zeros({1}, torch::kFloat16).to(device_); - - at_weight_tensors_[IN_QKV_BIAS_0] = - at_weight_tensors_[IN_QKV_BIAS_0].squeeze(); - at_weight_tensors_[IN_QKV_BIAS_1] = - at_weight_tensors_[IN_QKV_BIAS_1].squeeze(); - at_weight_tensors_[IN_QKV_BIAS_2] = - at_weight_tensors_[IN_QKV_BIAS_2].squeeze(); - - at_weight_tensors_[IN_QKV_BIAS_0] = - torch::cat({at_weight_tensors_[IN_QKV_BIAS_0], - at_weight_tensors_[IN_QKV_BIAS_1], - at_weight_tensors_[IN_QKV_BIAS_2]}, - 0) - .contiguous(); - at_weight_tensors_[IN_QKV_BIAS_1] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_BIAS_2] = - torch::zeros({1}, torch::kFloat16).to(device_); - - if (quantize_type_.compare("w8a8_dynamic") == 0) { - at_weight_tensors_[IN_QKV_DESCALE_0] = - at_weight_tensors_[IN_QKV_DESCALE_0].squeeze(); - at_weight_tensors_[IN_QKV_DESCALE_1] = - at_weight_tensors_[IN_QKV_DESCALE_1].squeeze(); - at_weight_tensors_[IN_QKV_DESCALE_2] = - at_weight_tensors_[IN_QKV_DESCALE_2].squeeze(); - - at_weight_tensors_[IN_QKV_DESCALE_0] = - torch::cat({at_weight_tensors_[IN_QKV_DESCALE_0], - at_weight_tensors_[IN_QKV_DESCALE_1], - at_weight_tensors_[IN_QKV_DESCALE_2]}, - 0) - .contiguous(); - - at_weight_tensors_[IN_QKV_DESCALE_1] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_DESCALE_2] = - torch::zeros({1}, torch::kFloat16).to(device_); - - at_weight_tensors_[IN_QKV_DENSE_BIAS] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_DENSE_DESCALE] = - torch::zeros({1}, torch::kFloat16).to(device_); - - at_weight_tensors_[IN_QKV_OFFSET_0] = - at_weight_tensors_[IN_QKV_OFFSET_0].to(torch::kInt8).to(device_); - at_weight_tensors_[IN_QKV_OFFSET_1] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_OFFSET_2] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_DENSE_OFFSET] = - 
at_weight_tensors_[IN_QKV_DENSE_OFFSET].contiguous().view(-1); - - at_weight_tensors_[IN_QKV_SCALE_1] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_SCALE_2] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_DENSE_SCALE] = - at_weight_tensors_[IN_QKV_DENSE_SCALE].contiguous().view(-1); - } + loader_->merge_loaded_weights(); + auto& at_weight_tensors = loader_->get_at_weight_tensors(); c10_npu::NPUCachingAllocator::emptyCache(); for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { atb_weight_tensors_[i] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[i]); } init_layer(); } -torch::Tensor Glm4MoeDecoderImpl::convert_fp16_to_int64( - const torch::Tensor& fp16_tensor) { - auto float_tensor = fp16_tensor.to(torch::kFloat32); - auto int32_tensor = float_tensor.view(torch::kInt32); - auto int64_tensor = int32_tensor.to(torch::kInt64); - return int64_tensor; -} - -void Glm4MoeDecoderImpl::convert_descaled_weights_to_float() { - auto convert_to_float = [this](int index) { - at_weight_tensors_[index] = at_weight_tensors_[index].to(torch::kFloat32); - }; - convert_to_float(IN_QKV_DENSE_DESCALE); -} -void Glm4MoeDecoderImpl::merge_shared_experts_weights() { - auto merge_and_clear = [this](int index, - torch::Tensor& shared_experts_gate, - torch::Tensor& shared_experts_up) { - at_weight_tensors_[index] = - torch::cat({shared_experts_gate, shared_experts_up}, 0) - .to(device_) - .contiguous(); - shared_experts_gate = tensor_placeholder_; - shared_experts_up = tensor_placeholder_; - }; - - if (layer_id_ >= prefill_param_.firstKDenseReplace) { - merge_and_clear( - IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, - shared_experts_weights_["mlp.shared_experts.gate_proj.weight"], - shared_experts_weights_["mlp.shared_experts.up_proj.weight"]); - if (quantize_type_ == "w8a8_dynamic") { - merge_and_clear( - IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, - shared_experts_weights_["mlp.shared_experts.gate_proj.weight_offset"], - shared_experts_weights_["mlp.shared_experts.up_proj.weight_offset"]); - merge_and_clear( - IN_MLP_GATEUP_SCALE_SHARED_EXPERT, - shared_experts_weights_["mlp.shared_experts.gate_proj.weight_scale"], - shared_experts_weights_["mlp.shared_experts.up_proj.weight_scale"]); - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT] = - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT].squeeze(); - at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT] = - at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT].squeeze(); - at_weight_tensors_[IN_MLP_DOWN_OFFSET_SHARED_EXPERT] = - at_weight_tensors_[IN_MLP_DOWN_OFFSET_SHARED_EXPERT].squeeze(); - at_weight_tensors_[IN_MLP_DOWN_SCALE_SHARED_EXPERT] = - at_weight_tensors_[IN_MLP_DOWN_SCALE_SHARED_EXPERT].squeeze(); - } - } else { - merge_and_clear(IN_MLP_GATEUP_WEIGHT_SHARED_EXPERT, - shared_experts_weights_["mlp.gate_proj.weight"], - shared_experts_weights_["mlp.up_proj.weight"]); - if (quantize_type_ == "w8a8_dynamic") { - merge_and_clear(IN_MLP_GATEUP_OFFSET_SHARED_EXPERT, - shared_experts_weights_["mlp.gate_proj.weight_offset"], - shared_experts_weights_["mlp.up_proj.weight_offset"]); - merge_and_clear(IN_MLP_GATEUP_SCALE_SHARED_EXPERT, - shared_experts_weights_["mlp.gate_proj.weight_scale"], - shared_experts_weights_["mlp.up_proj.weight_scale"]); - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT] = - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_SHARED_EXPERT].squeeze(); - at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT] = - 
at_weight_tensors_[IN_MLP_GATEUP_SCALE_SHARED_EXPERT].squeeze(); - } - } -} - -void Glm4MoeDecoderImpl::merge_experts_weights() { - try { - torch::Tensor mlp_gateup_weight; - if (quantize_type_.compare("w8a8_dynamic") == 0) { - mlp_gateup_weight = - merge_experts_weights(experts_weights_["gate_proj.weight"], - experts_weights_["up_proj.weight"], - /*transpose=*/true); - - at_weight_tensors_[IN_MLP_GATEUP_OFFSET] = - merge_experts_weights(experts_weights_["gate_proj.weight_offset"], - experts_weights_["up_proj.weight_offset"]); - at_weight_tensors_[IN_MLP_GATEUP_SCALE] = - merge_experts_weights(experts_weights_["gate_proj.weight_scale"], - experts_weights_["up_proj.weight_scale"]); - at_weight_tensors_[IN_MLP_GATEUP_WEIGHT] = - at_npu::native::npu_format_cast(mlp_gateup_weight, 29); - } else { - mlp_gateup_weight = - merge_experts_weights(experts_weights_["gate_proj.weight"], - experts_weights_["up_proj.weight"], - /*transpose=*/false); - at_weight_tensors_[IN_MLP_GATEUP_WEIGHT] = - at_npu::native::npu_format_cast(mlp_gateup_weight, 2).contiguous(); - } - } catch (const std::exception& e) { - LOG(ERROR) << "[ERROR] Exception in gateup weight processing: " << e.what(); - throw; - } - - if (experts_weights_.count("down_proj.weight") > 0) { - auto& down_weight = experts_weights_["down_proj.weight"]; - } - - try { - torch::Tensor mlp_down_weight = - merge_experts_weights(experts_weights_["down_proj.weight"], - /*transpose=*/false); - - at_weight_tensors_[IN_MLP_DOWN_WEIGHT] = - at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); - - if (quantize_type_.compare("w8a8_dynamic") == 0) { - at_weight_tensors_[IN_MLP_DOWN_OFFSET] = - merge_experts_weights(experts_weights_["down_proj.weight_offset"]); - at_weight_tensors_[IN_MLP_DOWN_SCALE] = - merge_experts_weights(experts_weights_["down_proj.weight_scale"]); - } - } catch (const std::exception& e) { - LOG(ERROR) << "[ERROR] Exception in down weight processing: " << e.what(); - throw; - } -} - -torch::Tensor Glm4MoeDecoderImpl::merge_experts_weights( - std::vector& experts, - bool transpose) { - torch::Tensor merged_tensor = torch::stack(experts, 0).to(device_); - if (transpose) { - merged_tensor = merged_tensor.transpose(1, 2); - } - merged_tensor = merged_tensor.contiguous(); - experts.clear(); - - return merged_tensor; -} - -torch::Tensor Glm4MoeDecoderImpl::merge_experts_weights( - std::vector& experts_gate, - std::vector& experts_up, - bool transpose) { - for (size_t i = 0; i < experts_up.size(); ++i) { - experts_gate[i] = torch::cat({experts_gate[i], experts_up[i]}, 0); - } - torch::Tensor merged_tensor = torch::stack(experts_gate, 0).to(device_); - if (transpose) { - merged_tensor = merged_tensor.transpose(1, 2); - } - merged_tensor = merged_tensor.contiguous(); - experts_gate.clear(); - experts_up.clear(); - - return merged_tensor; -} - int64_t Glm4MoeDecoderImpl::init_layer() { BaseLayer::name_ = "glm4_moe_decoder_layer " + std::to_string(layer_id_); model_name_ = "Glm4_Moe"; diff --git a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h index 74c8602b1..7daf174fd 100644 --- a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h +++ b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h @@ -25,6 +25,7 @@ limitations under the License. 
#include "framework/model/npu_dp_ep_padding.h" #include "framework/quant_args.h" #include "framework/state_dict/state_dict.h" +#include "loader/glm4_moe_decoder_loader.h" #include "npu_base_layer.h" #include "xllm_kernels/models/glm/layer/moe_decoder_layer.h" @@ -38,10 +39,6 @@ class Glm4MoeDecoderImpl : public BaseLayer { ~Glm4MoeDecoderImpl() {}; - void load_state_dict(const StateDict& state_dict); - - void verify_loaded_weights(const std::string& prefix) const; - void merge_loaded_weights(); torch::Tensor block_tables_placeholder_; @@ -73,8 +70,6 @@ class Glm4MoeDecoderImpl : public BaseLayer { const ParallelArgs& parallel_args, bool is_prefill); - void resize_experts_weights(int num_of_device_experts); - void initialize_basic_parameters(atb_speed::moe::MoeLayerParam& param, const ModelArgs& args, const ParallelArgs& parallel_args, @@ -93,68 +88,6 @@ class Glm4MoeDecoderImpl : public BaseLayer { void initialize_quantization_parameters(atb_speed::moe::MoeLayerParam& param); - torch::Tensor get_sharded_tensor(const StateDict& state_dict, - const std::string& name, - int dim); - torch::Tensor get_sharded_tensor(const StateDict& state_dict, - const std::string& name, - int dim, - int local_tp_rank, - int local_tp_size); - - std::string extract_endswith(const std::string& input); - - void set_kv_weight(const StateDict& state_dict, - const std::string& tensor_name, - int weight_position, - int dim); - - int extract_expert_index(const std::string& name); - - void convert_descaled_weights_to_float(); - - torch::Tensor convert_fp16_to_int64(const torch::Tensor& fp16_tensor); - - void merge_shared_experts_weights(); - - void merge_experts_weights(); - - void squeeze_experts_weights(); - - void preprocess_linear_for_rope(); - - void process_expert_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor); - - void process_shared_expert_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor); - - void process_mlp_common_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor); - - void process_general_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor); - - int get_mapped_index(const std::string& name, - const std::unordered_map& mapping); - - torch::Tensor view_tensor(torch::Tensor weight, - const std::string& name, - bool pre_view); - - torch::Tensor trans_rope_weight(torch::Tensor weight); - - torch::Tensor merge_experts_weights(std::vector& experts, - bool transpose = false); - - torch::Tensor merge_experts_weights(std::vector& experts_up, - std::vector& experts_gate, - bool transpose = false); - int64_t init_layer(); int64_t init_node(atb_speed::Model::Node& node, @@ -176,19 +109,12 @@ class Glm4MoeDecoderImpl : public BaseLayer { int32_t layer_id_; int32_t ep_size_; - int32_t num_experts_; int32_t num_experts_per_partition_; int32_t ep_local_tp_size_; int32_t ep_local_tp_rank_; int32_t start_expert_id_; int32_t end_expert_id_; int32_t ep_rank_; - int32_t n_kv_heads_; - - int32_t dp_size_; - int32_t dp_local_tp_size_; - int32_t dp_rank_; - int32_t dp_local_tp_rank_; int32_t num_speculative_tokens_ = 0; atb_speed::moe::MoeLayerParam prefill_param_; @@ -200,6 +126,7 @@ class Glm4MoeDecoderImpl : public BaseLayer { atb::Tensor internal_tensor_; torch::Tensor tensor_placeholder_; + torch::Tensor slot_tensor_placeholder_; torch::Tensor int_tensor_placeholder_; torch::Tensor decode_attn_mask_; @@ -209,14 +136,6 @@ class Glm4MoeDecoderImpl : 
public BaseLayer { torch::Tensor final_hidden_states_; torch::Tensor at_start_expert_id_; torch::Tensor at_in_device_expert_count_; - - std::vector int_placeholder_; - - std::unordered_map shared_experts_weights_; - std::unordered_map> experts_weights_; - - std::mutex shared_experts_mutex_; - std::mutex experts_mutex_; }; class Glm4MoeDecoder : public torch::nn::ModuleHolder { @@ -238,4 +157,4 @@ std::vector get_dtp_inputs(torch::Tensor token_size_per_dp_group, int32_t rank, at::Device device); } // namespace layer -} // namespace xllm +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp index 1f7acba0c..2d3c1a83b 100644 --- a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp @@ -22,6 +22,7 @@ limitations under the License. #include "common/global_flags.h" #include "core/layers/common/attention_mask_impl.h" +#include "loader/llama_decoder_loader.h" #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include "torch_npu/csrc/core/npu/NPUException.h" @@ -30,88 +31,6 @@ namespace layer { const uint64_t WEIGHT_COUNT_PER_LAYER = 50; -enum DecoderLayerTensorId : int { - - IN_NORM_WEIGHT = 0, // weight - IN_NORM_BIAS, // bias - IN_NORM_NEW_WEIGHT, // new weight - IN_NORM_NEW_BIAS, // new bias - - IN_Q_WEIGHT, // weight - IN_Q_BIAS, // bias - IN_Q_DEQSCALE, // deq_scale - IN_Q_OFFSET, // offset - IN_Q_SCALE, // scale - IN_Q_COMPRESS_IDX, - - IN_K_WEIGHT, // weight - IN_K_BIAS, // bias - IN_K_DEQSCALE, // deq_scale - IN_K_OFFSET, // offset - IN_K_SCALE, // scale - IN_K_COMPRESS_IDX, - - IN_V_WEIGHT, // weight - IN_V_BIAS, // bias - IN_V_DEQSCALE, // deq_scale - IN_V_OFFSET, // offset - IN_V_SCALE, // scale - IN_V_COMPRESS_IDX, - - IN_ATTENTION_OUT_WEIGHT, // weight - IN_ATTENTION_OUT_BIAS, // bias - IN_ATTENTION_OUT_DEQSCALE, // deq_scale - IN_ATTENTION_OUT_OFFSET, // offset - IN_ATTENTION_OUT_SCALE, // scale - IN_ATTENTION_OUT_COMPRESS_IDX, - - IN_SELFOUT_NORM_WEIGHT, // weight - IN_SELFOUT_NORM_BIAS, // bias - IN_SELFOUT_NORM_NEW_WEIGHT, // new weight - IN_SELFOUT_NORM_NEW_BIAS, // new bias - - IN_MLP_W2_WEIGHT, // weight - IN_MLP_W2_BIAS, // bias - IN_MLP_W2_DEQSCALE, // deq_scale - IN_MLP_W2_OFFSET, // offset - IN_MLP_W2_SCALE, // scale - IN_MLP_W2_COMPRESS_IDX, - - IN_MLP_W1_WEIGHT, // weight - IN_MLP_W1_BIAS, // bias - IN_MLP_W1_DEQSCALE, // deq_scale - IN_MLP_W1_OFFSET, // offset - IN_MLP_W1_SCALE, // scale - IN_MLP_W1_COMPRESS_IDX, - - IN_MLP_CPROJ_WEIGHT, // weight - IN_MLP_CPROJ_BIAS, // bias - IN_MLP_CPROJ_DEQSCALE, // deq_scale - IN_MLP_CPROJ_OFFSET, // offset - IN_MLP_CPROJ_SCALE, // scale - IN_MLP_CPROJ_COMPRESS_IDX, -}; - -static const std::unordered_map WEIGHT_MAPPING = { - {"input_layernorm.weight", IN_NORM_WEIGHT}, - {"self_attn.q_proj.weight", IN_Q_WEIGHT}, - {"self_attn.k_proj.weight", IN_K_WEIGHT}, - {"self_attn.v_proj.weight", IN_V_WEIGHT}, - {"self_attn.o_proj.weight", IN_ATTENTION_OUT_WEIGHT}, - {"post_attention_layernorm.weight", IN_SELFOUT_NORM_WEIGHT}, - {"mlp.gate_proj.weight", IN_MLP_W2_WEIGHT}, - {"mlp.up_proj.weight", IN_MLP_W1_WEIGHT}, - {"mlp.down_proj.weight", IN_MLP_CPROJ_WEIGHT}, -}; - -static std::map WEIGHT_SHARD = {{IN_Q_WEIGHT, 0}, - {IN_K_WEIGHT, 0}, - {IN_V_WEIGHT, 0}, - {IN_ATTENTION_OUT_WEIGHT, 1}, - {IN_MLP_W2_WEIGHT, 0}, - {IN_MLP_W1_WEIGHT, 0}, - {IN_MLP_CPROJ_WEIGHT, 1}}; - LlamaDecoderLayerImpl::LlamaDecoderLayerImpl(const ModelContext& context) : BaseLayer(context) 
{ param_from_args(prefill_param_, @@ -123,7 +42,6 @@ LlamaDecoderLayerImpl::LlamaDecoderLayerImpl(const ModelContext& context) context.get_parallel_args(), false); - at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); placeholder_vec_ = {1}; @@ -132,10 +50,10 @@ LlamaDecoderLayerImpl::LlamaDecoderLayerImpl(const ModelContext& context) device_id_ = options.device().index(); placeholder_ = atb_speed::Utils::AtTensor2Tensor( torch::zeros({1}).to(device_).to(dtype_)); + + loader_ = + std::make_unique(WEIGHT_COUNT_PER_LAYER, context); at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_); - for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { - at_weight_tensors_[i] = torch::zeros({1}).to(options); - } } // fix param @@ -175,49 +93,19 @@ void LlamaDecoderLayerImpl::param_from_args( // param.enableLogN = false; } -void LlamaDecoderLayerImpl::verify_loaded_weights() const { - for (const auto& [name, index] : WEIGHT_MAPPING) { - CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) - << "weight is not loaded for " << name; - } -} - void LlamaDecoderLayerImpl::merge_loaded_weights() { - auto new_q_weight = torch::cat({at_weight_tensors_[IN_Q_WEIGHT], - at_weight_tensors_[IN_K_WEIGHT], - at_weight_tensors_[IN_V_WEIGHT]}, - 0); - at_weight_tensors_[IN_Q_WEIGHT] = new_q_weight; - - at_weight_tensors_[IN_K_WEIGHT] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_V_WEIGHT] = torch::zeros({1}).to(device_); - - auto new_mlp_weight = torch::cat({at_weight_tensors_[IN_MLP_W2_WEIGHT], - at_weight_tensors_[IN_MLP_W1_WEIGHT]}, - 0); - at_weight_tensors_[IN_MLP_W2_WEIGHT] = new_mlp_weight; - - at_weight_tensors_[IN_MLP_W1_WEIGHT] = torch::zeros({1}).to(device_); + loader_->merge_loaded_weights(); + auto& at_weight_tensors = loader_->get_at_weight_tensors(); c10_npu::NPUCachingAllocator::emptyCache(); for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { atb_weight_tensors_[i] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[i]); } init_layer(); } -void LlamaDecoderLayerImpl::load_state_dict(const StateDict& state_dict) { - for (const auto& [name, index] : WEIGHT_MAPPING) { - if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { - set_weight(state_dict, name, index, WEIGHT_SHARD[index]); - } else { - set_weight(state_dict, name, index); - } - } -} - int64_t LlamaDecoderLayerImpl::init_layer() { init_attn_mask(); name_ = "llama_decoder_layer"; diff --git a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.h b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.h index 55fd0285a..672d9bffd 100644 --- a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.h @@ -49,10 +49,6 @@ class LlamaDecoderLayerImpl : public BaseLayer { ~LlamaDecoderLayerImpl() {}; - virtual void load_state_dict(const StateDict& state_dict) override; - - virtual void verify_loaded_weights() const override; - virtual void merge_loaded_weights() override; virtual int64_t init_layer() override; diff --git a/xllm/core/layers/npu/npu_lm_head_impl.cpp b/xllm/core/layers/npu/npu_lm_head_impl.cpp index 6eb78af4a..bdcddb3c6 100644 --- a/xllm/core/layers/npu/npu_lm_head_impl.cpp +++ b/xllm/core/layers/npu/npu_lm_head_impl.cpp @@ -79,43 +79,27 @@ LmHeadImpl::LmHeadImpl(const ModelContext& context) : BaseLayer(context) { context.get_parallel_args(), false); - at_weight_tensors_.resize(1); atb_weight_tensors_.resize(1); atOutTensors_.resize(1); auto options = 
context.get_tensor_options(); dtype_ = c10::typeMetaToScalarType(options.dtype()); - at_weight_tensors_[0] = torch::zeros({1}).to(options); prefill_tensor_storage_.resize(2); decode_tensor_storage_.resize(2); torch_placeholder_ = torch::zeros({1}).to(device_).to(dtype_); placeholder_ = atb_speed::Utils::AtTensor2Tensor(torch_placeholder_); -} -void LmHeadImpl::verify_loaded_weights(const std::string weight_str) const { - // std::cout<({1})) - << "final lm_head weight is not loaded for " << weight_str; + loader_ = std::make_unique(1, context); } void LmHeadImpl::merge_loaded_weights() { + auto& at_weight_tensors = loader_->get_at_weight_tensors(); atb_weight_tensors_[0] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[0]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[0]); init_layer(); } -void LmHeadImpl::load_state_dict(const StateDict& state_dict) { - // set_weight(state_dict, "weight", 0, 0); - if (dp_size_ > 1) { - set_weight( - state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_); - } else { - set_weight(state_dict, "weight", 0, 0); - } -} - int64_t LmHeadImpl::init_layer() { BaseLayer::name_ = "lm_head_layer"; model_name_ = "lm"; diff --git a/xllm/core/layers/npu/npu_lm_head_impl.h b/xllm/core/layers/npu/npu_lm_head_impl.h index 085b6bdb5..ecd8a4669 100644 --- a/xllm/core/layers/npu/npu_lm_head_impl.h +++ b/xllm/core/layers/npu/npu_lm_head_impl.h @@ -29,6 +29,7 @@ limitations under the License. #include "atb/atb_infer.h" #include "framework/model/model_input_params.h" #include "framework/model_context.h" +#include "loader/lm_head_loader.h" #include "nlohmann/json.hpp" #include "npu_base_layer.h" #include "pytorch/adapter/utils/utils.h" @@ -47,10 +48,6 @@ class LmHeadImpl : public BaseLayer { ~LmHeadImpl() {}; - void load_state_dict(const StateDict& state_dict) override; - - void verify_loaded_weights(const std::string weight_str) const; - void merge_loaded_weights() override; torch::Tensor forward(const torch::Tensor& hidden_states, diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp index 5150692a5..095e15d81 100644 --- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp @@ -32,84 +32,6 @@ namespace layer { const uint64_t WEIGHT_COUNT_PER_LAYER = 50; -static std::vector> WEIGHT_MAPPING = { - {IN_NORM_WEIGHT, "input_layernorm.weight"}, - {IN_Q_WEIGHT, "self_attn.q_proj.weight"}, - {IN_Q_BIAS, "self_attn.q_proj.bias"}, - {IN_K_WEIGHT, "self_attn.k_proj.weight"}, - {IN_K_BIAS, "self_attn.k_proj.bias"}, - {IN_V_WEIGHT, "self_attn.v_proj.weight"}, - {IN_V_BIAS, "self_attn.v_proj.bias"}, - {IN_ATTENTION_OUT_WEIGHT, "self_attn.o_proj.weight"}, - {IN_SELFOUT_NORM_WEIGHT, "post_attention_layernorm.weight"}, - {IN_MLP_W2_WEIGHT, "mlp.gate_proj.weight"}, - {IN_MLP_W1_WEIGHT, "mlp.up_proj.weight"}, - {IN_MLP_CPROJ_WEIGHT, "mlp.down_proj.weight"}}; - -static std::vector> WEIGHT_MAPPING_W8A8 = { - {IN_NORM_WEIGHT, "input_layernorm.weight"}, - {IN_Q_WEIGHT, "self_attn.q_proj.weight"}, - {IN_Q_BIAS, "self_attn.q_proj.quant_bias"}, - {IN_Q_DEQSCALE, "self_attn.q_proj.deq_scale"}, - {IN_Q_OFFSET, "self_attn.q_proj.input_offset"}, - {IN_Q_SCALE, "self_attn.q_proj.input_scale"}, - {IN_K_WEIGHT, "self_attn.k_proj.weight"}, - {IN_K_BIAS, "self_attn.k_proj.quant_bias"}, - {IN_K_DEQSCALE, "self_attn.k_proj.deq_scale"}, - {IN_K_OFFSET, "self_attn.k_proj.input_offset"}, - {IN_K_SCALE, "self_attn.k_proj.input_scale"}, - {IN_V_WEIGHT, 
"self_attn.v_proj.weight"}, - {IN_V_BIAS, "self_attn.v_proj.quant_bias"}, - {IN_V_DEQSCALE, "self_attn.v_proj.deq_scale"}, - {IN_V_OFFSET, "self_attn.v_proj.input_offset"}, - {IN_V_SCALE, "self_attn.v_proj.input_scale"}, - {IN_ATTENTION_OUT_WEIGHT, "self_attn.o_proj.weight"}, - {IN_ATTENTION_OUT_BIAS, "self_attn.o_proj.quant_bias"}, - {IN_ATTENTION_OUT_DEQSCALE, "self_attn.o_proj.deq_scale"}, - {IN_ATTENTION_OUT_OFFSET, "self_attn.o_proj.input_offset"}, - {IN_ATTENTION_OUT_SCALE, "self_attn.o_proj.input_scale"}, - {IN_SELFOUT_NORM_WEIGHT, "post_attention_layernorm.weight"}, - {IN_MLP_W2_WEIGHT, "mlp.gate_proj.weight"}, - {IN_MLP_W2_BIAS, "mlp.gate_proj.quant_bias"}, - {IN_MLP_W2_DEQSCALE, "mlp.gate_proj.deq_scale"}, - {IN_MLP_W2_OFFSET, "mlp.gate_proj.input_offset"}, - {IN_MLP_W2_SCALE, "mlp.gate_proj.input_scale"}, - {IN_MLP_W1_WEIGHT, "mlp.up_proj.weight"}, - {IN_MLP_W1_BIAS, "mlp.up_proj.quant_bias"}, - {IN_MLP_W1_DEQSCALE, "mlp.up_proj.deq_scale"}, - {IN_MLP_W1_OFFSET, "mlp.up_proj.input_offset"}, - {IN_MLP_W1_SCALE, "mlp.up_proj.input_scale"}, - {IN_MLP_CPROJ_WEIGHT, "mlp.down_proj.weight"}}; - -static std::map WEIGHT_SHARD = {{IN_Q_WEIGHT, 0}, - {IN_Q_BIAS, 0}, - {IN_K_WEIGHT, 0}, - {IN_K_BIAS, 0}, - {IN_V_WEIGHT, 0}, - {IN_V_BIAS, 0}, - {IN_ATTENTION_OUT_WEIGHT, 1}, - {IN_MLP_W2_WEIGHT, 0}, - {IN_MLP_W1_WEIGHT, 0}, - {IN_MLP_CPROJ_WEIGHT, 1}}; - -static std::map WEIGHT_SHARD_W8A8 = {{IN_Q_WEIGHT, 0}, - {IN_Q_BIAS, 0}, - {IN_Q_DEQSCALE, 0}, - {IN_K_WEIGHT, 0}, - {IN_K_BIAS, 0}, - {IN_K_DEQSCALE, 0}, - {IN_V_WEIGHT, 0}, - {IN_V_BIAS, 0}, - {IN_V_DEQSCALE, 0}, - {IN_ATTENTION_OUT_WEIGHT, 1}, - {IN_MLP_W2_WEIGHT, 0}, - {IN_MLP_W2_BIAS, 0}, - {IN_MLP_W2_DEQSCALE, 0}, - {IN_MLP_W1_WEIGHT, 0}, - {IN_MLP_W1_BIAS, 0}, - {IN_MLP_W1_DEQSCALE, 0}, - {IN_MLP_CPROJ_WEIGHT, 1}}; - void Qwen2DecoderLayerImpl::param_from_args( atb_speed::qwen::DecoderLayerParam& param, const ModelArgs& args, @@ -163,7 +85,6 @@ Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context) param_from_args(prefill_param_, model_args, parallel_args, true); param_from_args(decode_param_, model_args, parallel_args, false); - at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); placeholder_vec_ = {1}; dtype_ = c10::typeMetaToScalarType(options.dtype()); @@ -175,146 +96,35 @@ Qwen2DecoderLayerImpl::Qwen2DecoderLayerImpl(const ModelContext& context) placeholder_ = atb_speed::Utils::AtTensor2Tensor( torch::zeros({1}).to(device_).to(dtype_)); at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_); - for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { - at_weight_tensors_[i] = torch::zeros({1}).to(options); - } -} - -void Qwen2DecoderLayerImpl::verify_loaded_weights() const { - for (const auto& [index, name] : WEIGHT_MAPPING) { - CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) - << "weight is not loaded for " << name; - } + loader_ = + std::make_unique(WEIGHT_COUNT_PER_LAYER, context); + initialize_quantization_parameters(); } -TransposeType Qwen2DecoderLayerImpl::check_transpose(at::Tensor& tensor) { - bool is_k_divisible = tensor.size(1) % 256 == 0; - bool is_n_divisible = tensor.size(0) % 256 == 0; - - if (!is_k_divisible && is_n_divisible) { - return TransposeType::NOT_TRANSPOSE; - } - - return TransposeType::TRANSPOSE; -} - -void Qwen2DecoderLayerImpl::merge_loaded_weights() { - if (quantize_type_ == "w8a8") { - at_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE] = - at_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE].to(torch::kFloat32); - 
at_weight_tensors_[IN_Q_DEQSCALE] = - torch::cat({at_weight_tensors_[IN_Q_DEQSCALE], - at_weight_tensors_[IN_K_DEQSCALE], - at_weight_tensors_[IN_V_DEQSCALE]}, - 0) - .to(torch::kFloat32); - at_weight_tensors_[IN_K_DEQSCALE] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_V_DEQSCALE] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_K_OFFSET] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_V_OFFSET] = torch::zeros({1}).to(device_); - - at_weight_tensors_[IN_K_SCALE] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_V_SCALE] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_MLP_W2_BIAS] = - torch::cat({at_weight_tensors_[IN_MLP_W2_BIAS], - at_weight_tensors_[IN_MLP_W1_BIAS]}, - 0); - at_weight_tensors_[IN_MLP_W1_BIAS] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_MLP_W2_DEQSCALE] = - torch::cat({at_weight_tensors_[IN_MLP_W2_DEQSCALE], - at_weight_tensors_[IN_MLP_W1_DEQSCALE]}, - 0) - .to(torch::kFloat32); - at_weight_tensors_[IN_MLP_W1_DEQSCALE] = torch::zeros({1}).to(device_); - - at_weight_tensors_[IN_MLP_W1_OFFSET] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_MLP_W1_SCALE] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_Q_OFFSET] = - at_weight_tensors_[IN_Q_OFFSET].to(torch::kInt8).to(device_); - at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] = - at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] - .to(torch::kInt8) - .to(device_); - at_weight_tensors_[IN_MLP_W2_OFFSET] = - at_weight_tensors_[IN_MLP_W2_OFFSET].to(torch::kInt8).to(device_); - if (device_id_ != 0) { - torch::Tensor original_tensor = at_weight_tensors_[IN_ATTENTION_OUT_BIAS]; - auto shape = original_tensor.sizes(); - auto dtype = original_tensor.dtype(); - auto device = original_tensor.device(); - - at_weight_tensors_[IN_ATTENTION_OUT_BIAS] = torch::zeros( - shape, torch::TensorOptions().dtype(dtype).device(device)); - } - } - - auto new_q_weight = torch::cat({at_weight_tensors_[IN_Q_WEIGHT], - at_weight_tensors_[IN_K_WEIGHT], - at_weight_tensors_[IN_V_WEIGHT]}, - 0); - - at_weight_tensors_[IN_Q_WEIGHT] = new_q_weight; - - at_weight_tensors_[IN_K_WEIGHT] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_V_WEIGHT] = torch::zeros({1}).to(device_); - - auto new_q_bias = torch::cat({at_weight_tensors_[IN_Q_BIAS], - at_weight_tensors_[IN_K_BIAS], - at_weight_tensors_[IN_V_BIAS]}, - 0); - at_weight_tensors_[IN_Q_BIAS] = new_q_bias; - - at_weight_tensors_[IN_K_BIAS] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_V_BIAS] = torch::zeros({1}).to(device_); - +void Qwen2DecoderLayerImpl::initialize_linear_transpose_type() { + auto& at_weight_tensors = loader_->get_at_weight_tensors(); TransposeType transpose_type = - check_transpose(at_weight_tensors_[IN_MLP_W2_WEIGHT]); + check_transpose(at_weight_tensors[IN_MLP_W2_WEIGHT]); int transpose_value = static_cast(transpose_type); prefill_param_.linearTransposeType[4] = transpose_value; decode_param_.linearTransposeType[4] = transpose_value; - if (transpose_type == TransposeType::TRANSPOSE) { - auto new_mlp_weight = torch::cat({at_weight_tensors_[IN_MLP_W2_WEIGHT], - at_weight_tensors_[IN_MLP_W1_WEIGHT]}, - 0); - at_weight_tensors_[IN_MLP_W2_WEIGHT] = new_mlp_weight.contiguous(); - } else { - auto new_mlp_weight = torch::cat({at_weight_tensors_[IN_MLP_W2_WEIGHT], - at_weight_tensors_[IN_MLP_W1_WEIGHT]}, - 0) - .transpose(0, 1); - at_weight_tensors_[IN_MLP_W2_WEIGHT] = new_mlp_weight.contiguous(); - } - - at_weight_tensors_[IN_MLP_W1_WEIGHT] = torch::zeros({1}).to(device_); +} +void 
Qwen2DecoderLayerImpl::merge_loaded_weights() { + initialize_linear_transpose_type(); + loader_->merge_loaded_weights(); + auto& at_weight_tensors = loader_->get_at_weight_tensors(); c10_npu::NPUCachingAllocator::emptyCache(); for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { atb_weight_tensors_[i] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[i]); } init_layer(); } -void Qwen2DecoderLayerImpl::load_state_dict(const StateDict& state_dict) { +void Qwen2DecoderLayerImpl::initialize_quantization_parameters() { if (quantize_type_ == "w8a8") { - for (const auto& [index, name] : WEIGHT_MAPPING_W8A8) { - if (WEIGHT_SHARD_W8A8.find(index) != WEIGHT_SHARD_W8A8.end()) { - set_weight(state_dict, name, index, WEIGHT_SHARD_W8A8[index]); - } else { - set_weight(state_dict, name, index); - } - } - at_weight_tensors_[IN_NORM_BIAS] = - torch::zeros(at_weight_tensors_[IN_NORM_WEIGHT].sizes(), - at_weight_tensors_[IN_NORM_WEIGHT].options()) - .to(device_); - - at_weight_tensors_[IN_SELFOUT_NORM_BIAS] = - torch::zeros(at_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].sizes(), - at_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].options()) - .to(device_); - prefill_param_.packQuantType = {static_cast(PackType::ALL_W8A8), static_cast(PackType::ALL_W8A8)}; decode_param_.packQuantType = {static_cast(PackType::ALL_W8A8), @@ -333,16 +143,18 @@ void Qwen2DecoderLayerImpl::load_state_dict(const StateDict& state_dict) { static_cast(LinearType::INT), static_cast(LinearType::INVALID), static_cast(LinearType::FP)}; - return; } +} + +TransposeType Qwen2DecoderLayerImpl::check_transpose(at::Tensor& tensor) { + bool is_k_divisible = tensor.size(1) % 256 == 0; + bool is_n_divisible = tensor.size(0) % 256 == 0; - for (const auto& [index, name] : WEIGHT_MAPPING) { - if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { - set_weight(state_dict, name, index, WEIGHT_SHARD[index]); - } else { - set_weight(state_dict, name, index); - } + if (!is_k_divisible && is_n_divisible) { + return TransposeType::NOT_TRANSPOSE; } + + return TransposeType::TRANSPOSE; } int64_t Qwen2DecoderLayerImpl::init_layer() { diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h index f5fe63e6c..87ae4f285 100644 --- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h @@ -31,6 +31,7 @@ limitations under the License. #include "framework/model/model_input_params.h" #include "framework/model_context.h" #include "framework/state_dict/state_dict.h" +#include "loader/qwen2_decoder_loader.h" #include "nlohmann/json.hpp" #include "npu_base_layer.h" #include "pytorch/adapter/utils/utils.h" @@ -43,79 +44,12 @@ limitations under the License. 
namespace xllm { namespace layer { -enum DecoderLayerTensorId : int { - IN_NORM_WEIGHT = 0, // weight - IN_NORM_BIAS = 1, // bias - IN_NORM_NEW_WEIGHT = 2, // new weight - IN_NORM_NEW_BIAS = 3, // new bias - - IN_Q_WEIGHT = 4, // weight - IN_Q_BIAS = 5, // bias - IN_Q_DEQSCALE = 6, // deq_scale - IN_Q_OFFSET = 7, // offset - IN_Q_SCALE = 8, // scale - IN_Q_COMPRESS_IDX = 9, - - IN_K_WEIGHT = 10, // weight - IN_K_BIAS = 11, // bias - IN_K_DEQSCALE = 12, // deq_scale - IN_K_OFFSET = 13, // offset - IN_K_SCALE = 14, // scale - IN_K_COMPRESS_IDX = 15, - - IN_V_WEIGHT = 16, // weight - IN_V_BIAS = 17, // bias - IN_V_DEQSCALE = 18, // deq_scale - IN_V_OFFSET = 19, // offset - IN_V_SCALE = 20, // scale - IN_V_COMPRESS_IDX = 21, - - IN_ATTENTION_OUT_WEIGHT = 22, // weight - IN_ATTENTION_OUT_BIAS = 23, // bias - IN_ATTENTION_OUT_DEQSCALE = 24, // deq_scale - IN_ATTENTION_OUT_OFFSET = 25, // offset - IN_ATTENTION_OUT_SCALE = 26, // scale - IN_ATTENTION_OUT_COMPRESS_IDX = 27, - - IN_SELFOUT_NORM_WEIGHT = 28, // weight - IN_SELFOUT_NORM_BIAS = 29, // bias - IN_SELFOUT_NORM_NEW_WEIGHT = 30, // new weight - IN_SELFOUT_NORM_NEW_BIAS = 31, // new bias - - IN_MLP_W2_WEIGHT = 32, // weight - IN_MLP_W2_BIAS = 33, // bias - IN_MLP_W2_DEQSCALE = 34, // deq_scale - IN_MLP_W2_OFFSET = 35, // offset - IN_MLP_W2_SCALE = 36, // scale - IN_MLP_W2_COMPRESS_IDX = 37, - - IN_MLP_W1_WEIGHT = 38, // weight - IN_MLP_W1_BIAS = 39, // bias - IN_MLP_W1_DEQSCALE = 40, // deq_scale - IN_MLP_W1_OFFSET = 41, // offset - IN_MLP_W1_SCALE = 42, // scale - IN_MLP_W1_COMPRESS_IDX = 43, - - IN_MLP_CPROJ_WEIGHT = 44, // weight - IN_MLP_CPROJ_BIAS = 45, // bias - IN_MLP_CPROJ_DEQSCALE = 46, // deq_scale - IN_MLP_CPROJ_OFFSET = 47, // offset - IN_MLP_CPROJ_SCALE = 48, // scale - IN_MLP_CPROJ_COMPRESS_IDX = 49, -}; - class Qwen2DecoderLayerImpl : public BaseLayer { public: explicit Qwen2DecoderLayerImpl(const ModelContext& context); ~Qwen2DecoderLayerImpl() {}; - TransposeType check_transpose(at::Tensor& tensor); - - virtual void load_state_dict(const StateDict& state_dict) override; - - virtual void verify_loaded_weights() const override; - virtual void merge_loaded_weights() override; virtual int64_t init_layer() override; @@ -131,6 +65,12 @@ class Qwen2DecoderLayerImpl : public BaseLayer { int node_id = 0); private: + TransposeType check_transpose(at::Tensor& tensor); + + void initialize_quantization_parameters(); + + void initialize_linear_transpose_type(); + void build_node_variant_pack(atb_speed::Model::Node& node, torch::Tensor& x, torch::Tensor& cos_pos, diff --git a/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.cpp index 71bfb82f3..c4b22ecb7 100644 --- a/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.cpp @@ -29,31 +29,6 @@ namespace layer { const uint64_t WEIGHT_COUNT_PER_LAYER = 18; -static std::vector> WEIGHT_MAPPING = { - {IN_INPUT_NORM_WEIGHT, "norm1.weight"}, - {IN_POST_NORM_WEIGHT, "norm2.weight"}, - {IN_QKV_WEIGHT, "qkv.weight"}, - {IN_QKV_BIAS, "qkv.bias"}, - {IN_WATTENTION_OUT_WEIGHT, "attn.proj.weight"}, - {IN_WATTENTION_OUT_BIAS, "attn.proj.bias"}, - {IN_MLP_GATE_WEIGHT, "mlp.gate_proj.weight"}, - {IN_MLP_GATE_BIAS, "mlp.gate_proj.bias"}, - {IN_MLP_UP_WEIGHT, "mlp.up_proj.weight"}, - {IN_MLP_UP_BIAS, "mlp.up_proj.bias"}, - {IN_MLP_DOWN_WEIGHT, "mlp.down_proj.weight"}, - {IN_MLP_DOWN_BIAS, "mlp.down_proj.bias"}, -}; - -// {weight,dim} -static std::map 
WEIGHT_SHARD = { - {IN_WATTENTION_OUT_WEIGHT, 1}, - {IN_MLP_GATE_WEIGHT, 0}, - {IN_MLP_GATE_BIAS, 0}, - {IN_MLP_UP_WEIGHT, 0}, - {IN_MLP_UP_BIAS, 0}, - {IN_MLP_DOWN_WEIGHT, 1}, -}; - void Qwen2dot5VisionEncoderLayerImpl::param_from_args( atb_speed::qwen::VisionEncoderLayerParam& param, const ModelArgs& args, @@ -86,198 +61,22 @@ Qwen2dot5VisionEncoderLayerImpl::Qwen2dot5VisionEncoderLayerImpl( device_id_ = options.device().index(); placeholder_ = atb_speed::Utils::AtTensor2Tensor( torch::zeros({1}).to(device_).to(dtype_)); - at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_); - for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { - at_weight_tensors_[i] = torch::zeros({1}).to(options); - } -} - -void Qwen2dot5VisionEncoderLayerImpl::verify_loaded_weights() const { - for (const auto& [index, name] : WEIGHT_MAPPING) { - CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) - << "weight is not loaded for " << name; - } -} -void Qwen2dot5VisionEncoderLayerImpl::pad_mlp_weights() { - torch::Tensor weight = at_weight_tensors_[IN_MLP_GATE_WEIGHT]; - torch::Tensor bias = at_weight_tensors_[IN_MLP_GATE_BIAS]; - - int64_t tp_intermediate_size_half = weight.size(0) / 2; - int64_t remainder = tp_intermediate_size_half % 32; - int64_t tp_intermediate_size_half_pad; - if (remainder != 0) { - tp_intermediate_size_half_pad = - tp_intermediate_size_half + (32 - remainder); - } else { - tp_intermediate_size_half_pad = tp_intermediate_size_half; - } - - auto weight_split1 = weight.slice(0, 0, tp_intermediate_size_half); - auto weight_split2 = weight.slice(0, tp_intermediate_size_half); - auto bias_split1 = bias.slice(0, 0, tp_intermediate_size_half); - auto bias_split2 = bias.slice(0, tp_intermediate_size_half); - - auto weight_split1_padded = - pad_tensor(weight_split1, tp_intermediate_size_half_pad); - auto weight_split2_padded = - pad_tensor(weight_split2, tp_intermediate_size_half_pad); - auto bias_split1_padded = - pad_tensor(bias_split1, tp_intermediate_size_half_pad); - auto bias_split2_padded = - pad_tensor(bias_split2, tp_intermediate_size_half_pad); - - auto weight_padded = - torch::cat({weight_split1_padded, weight_split2_padded}, 0); - auto bias_padded = torch::cat({bias_split1_padded, bias_split2_padded}, 0); - at_weight_tensors_[IN_MLP_GATE_WEIGHT] = weight_padded; - at_weight_tensors_[IN_MLP_GATE_BIAS] = bias_padded; - - torch::Tensor down_weight = at_weight_tensors_[IN_MLP_DOWN_WEIGHT]; - - auto tp_intermediate_size = down_weight.size(1); - remainder = tp_intermediate_size % 32; - int64_t tp_intermediate_size_pad; - if (remainder != 0) { - tp_intermediate_size_pad = tp_intermediate_size + (32 - remainder); - } else { - tp_intermediate_size_pad = tp_intermediate_size; - } - - auto down_weight_padded = - pad_tensor(down_weight, tp_intermediate_size_pad, 1); - at_weight_tensors_[IN_MLP_DOWN_WEIGHT] = down_weight_padded; -} -void Qwen2dot5VisionEncoderLayerImpl::pad_qkv_weights() { - auto qkv_proj_weight = at_weight_tensors_[IN_QKV_WEIGHT]; - auto qkv_proj_bias = at_weight_tensors_[IN_QKV_BIAS]; - int num_heads_pre_rank = encode_param_.numAttentionHeadsPerRank; - int hidden_size = num_heads_pre_rank * 80 * encode_param_.worldSize; - - auto qkv_proj_weight_reshaped = - qkv_proj_weight.reshape({num_heads_pre_rank, 3, 80, hidden_size}); - - auto first_half = qkv_proj_weight_reshaped.slice(2, 0, 40); - auto second_half = qkv_proj_weight_reshaped.slice(2, 40, 80); - - auto first_half_padded = torch::nn::functional::pad( - first_half, torch::nn::functional::PadFuncOptions({0, 0, 0, 24})); 
- auto second_half_padded = torch::nn::functional::pad( - second_half, torch::nn::functional::PadFuncOptions({0, 0, 0, 24})); - - auto qkv_proj_weight_padded = - torch::cat({first_half_padded, second_half_padded}, 2); - auto qkv_proj_weight_final = qkv_proj_weight_padded.reshape( - {num_heads_pre_rank * 128 * 3, hidden_size}); - qkv_proj_weight_final = - at_npu::native::npu_format_cast(qkv_proj_weight_final, 2); - - auto qkv_proj_bias_reshaped = - qkv_proj_bias.reshape({num_heads_pre_rank, 3, 80}); - first_half = qkv_proj_bias_reshaped.slice(2, 0, 40); - second_half = qkv_proj_bias_reshaped.slice(2, 40, 80); - - first_half_padded = torch::nn::functional::pad( - first_half, torch::nn::functional::PadFuncOptions({0, 24})); - second_half_padded = torch::nn::functional::pad( - second_half, torch::nn::functional::PadFuncOptions({0, 24})); - auto qkv_proj_bias_padded = - torch::cat({first_half_padded, second_half_padded}, 2); - auto qkv_proj_bias_final = - qkv_proj_bias_padded.reshape({num_heads_pre_rank * 128 * 3}); - - at_weight_tensors_[IN_QKV_WEIGHT] = qkv_proj_weight_final; - at_weight_tensors_[IN_QKV_BIAS] = qkv_proj_bias_final; - - auto out_proj_weight = at_weight_tensors_[IN_WATTENTION_OUT_WEIGHT]; - - out_proj_weight = - torch::nn::functional::pad( - out_proj_weight.reshape({hidden_size, num_heads_pre_rank * 2, 40}), - torch::nn::functional::PadFuncOptions({0, 24, 0, 0})) - .reshape({hidden_size, num_heads_pre_rank * 128}); - at_weight_tensors_[IN_WATTENTION_OUT_WEIGHT] = out_proj_weight; + loader_ = std::make_unique( + WEIGHT_COUNT_PER_LAYER, context, encode_param_.numAttentionHeadsPerRank); } void Qwen2dot5VisionEncoderLayerImpl::merge_loaded_weights() { - // spilt pack qkv weight when enable tp - get_weights_col_packed_qkv(); - if (encode_param_.worldSize > 1) { - // merge qkv weight - auto new_qkv_weight = torch::cat({at_weight_tensors_[IN_VISION_Q_WEIGHT], - at_weight_tensors_[IN_VISION_K_WEIGHT], - at_weight_tensors_[IN_VISION_V_WEIGHT]}, - 0); - at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight; - at_weight_tensors_[IN_VISION_Q_WEIGHT] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_VISION_K_WEIGHT] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_VISION_V_WEIGHT] = torch::zeros({1}).to(device_); - - // merge qkv bias - auto new_qkv_bias = torch::cat({at_weight_tensors_[IN_VISION_Q_BIAS], - at_weight_tensors_[IN_VISION_K_BIAS], - at_weight_tensors_[IN_VISION_V_BIAS]}, - 0); - at_weight_tensors_[IN_QKV_BIAS] = new_qkv_bias; - at_weight_tensors_[IN_VISION_Q_BIAS] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_VISION_K_BIAS] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_VISION_V_BIAS] = torch::zeros({1}).to(device_); - } - // pad qkv weights - pad_qkv_weights(); - // merge gate up - auto new_mlp_weight = torch::cat({at_weight_tensors_[IN_MLP_GATE_WEIGHT], - at_weight_tensors_[IN_MLP_UP_WEIGHT]}, - 0); - at_weight_tensors_[IN_MLP_GATE_WEIGHT] = new_mlp_weight; - auto new_mlp_bias = torch::cat({at_weight_tensors_[IN_MLP_GATE_BIAS], - at_weight_tensors_[IN_MLP_UP_BIAS]}, - 0); - at_weight_tensors_[IN_MLP_GATE_BIAS] = new_mlp_bias; - at_weight_tensors_[IN_MLP_UP_BIAS] = torch::zeros({1}).to(device_); - // pad mlp weights - pad_mlp_weights(); + loader_->merge_loaded_weights(); + auto& at_weight_tensors = loader_->get_at_weight_tensors(); c10_npu::NPUCachingAllocator::emptyCache(); for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { + LOG(INFO) << "device: " << at_weight_tensors[i].device(); atb_weight_tensors_[i] = - 
atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[i]); } init_layer(); } -// tp spilt weight -void Qwen2dot5VisionEncoderLayerImpl::get_weights_col_packed_qkv() { - int rank = encode_param_.rank; - int worldSize = encode_param_.worldSize; - // split qkv weight - qkv_weight = torch::chunk(at_weight_tensors_[IN_QKV_WEIGHT], 3, 0); - qkv_bias = torch::chunk(at_weight_tensors_[IN_QKV_BIAS], 3, 0); - // weight - at_weight_tensors_[IN_VISION_Q_WEIGHT] = - (qkv_weight[0].chunk(worldSize, 0))[rank]; - at_weight_tensors_[IN_VISION_K_WEIGHT] = - (qkv_weight[1].chunk(worldSize, 0))[rank]; - at_weight_tensors_[IN_VISION_V_WEIGHT] = - (qkv_weight[2].chunk(worldSize, 0))[rank]; - // bias - at_weight_tensors_[IN_VISION_Q_BIAS] = - (qkv_bias[0].chunk(worldSize, 0))[rank]; - at_weight_tensors_[IN_VISION_K_BIAS] = - (qkv_bias[1].chunk(worldSize, 0))[rank]; - at_weight_tensors_[IN_VISION_V_BIAS] = - (qkv_bias[2].chunk(worldSize, 0))[rank]; -} - -void Qwen2dot5VisionEncoderLayerImpl::load_state_dict( - const StateDict& state_dict) { - for (const auto& [index, name] : WEIGHT_MAPPING) { - if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { - set_weight(state_dict, name, index, WEIGHT_SHARD[index]); - } else { - set_weight(state_dict, name, index); - } - } - get_weights_col_packed_qkv(); -} int64_t Qwen2dot5VisionEncoderLayerImpl::init_layer() { name_ = "qwen2_5_encoder_layer"; @@ -368,9 +167,6 @@ void Qwen2dot5VisionEncoderLayerImpl::build_node_variant_pack( CHECK_THROW(node.inTensors.at(i) == nullptr, model_name_ << "inTensor " << i << "is NULL"); node.variantPack.inTensors.at(i) = *node.inTensors.at(i); - // LOG(INFO) << model_name_ << "inTensors[" << i << "]:" - // << atb_speed::TensorUtil::TensorToString( - // node.variantPack.inTensors.at(i)); } node.variantPack.outTensors.at(0) = internal_tensors_; diff --git a/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.h index eec1f4376..44788f288 100644 --- a/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen2dot5_vision_encoder_layer_impl.h @@ -34,6 +34,7 @@ limitations under the License. #include "core/framework/model/model_args.h" #include "core/framework/model/model_input_params.h" #include "core/framework/state_dict/state_dict.h" +#include "loader/qwen2dot5_vision_encoder_loader.h" #include "nlohmann/json.hpp" #include "npu_base_layer.h" #include "pytorch/adapter/utils/utils.h" @@ -42,37 +43,12 @@ limitations under the License. 
namespace xllm { namespace layer { -enum VisionEncoderLayerTensorId : int { - IN_INPUT_NORM_WEIGHT = 0, - IN_POST_NORM_WEIGHT, - IN_QKV_WEIGHT, - IN_QKV_BIAS, - IN_WATTENTION_OUT_WEIGHT, - IN_WATTENTION_OUT_BIAS, - IN_MLP_GATE_WEIGHT, - IN_MLP_GATE_BIAS, - IN_MLP_UP_WEIGHT, - IN_MLP_UP_BIAS, - IN_MLP_DOWN_WEIGHT, - IN_MLP_DOWN_BIAS, - IN_VISION_Q_WEIGHT, - IN_VISION_Q_BIAS, - IN_VISION_K_WEIGHT, - IN_VISION_K_BIAS, - IN_VISION_V_WEIGHT, - IN_VISION_V_BIAS -}; - class Qwen2dot5VisionEncoderLayerImpl : public BaseLayer { public: explicit Qwen2dot5VisionEncoderLayerImpl(const ModelContext& context); ~Qwen2dot5VisionEncoderLayerImpl() {}; - void load_state_dict(const StateDict& state_dict) override; - - void verify_loaded_weights() const override; - void merge_loaded_weights() override; int64_t init_layer() override; @@ -97,8 +73,6 @@ class Qwen2dot5VisionEncoderLayerImpl : public BaseLayer { ModelInputParams& input_params, bool is_prefill); - void get_weights_col_packed_qkv(); - void param_from_args(atb_speed::qwen::VisionEncoderLayerParam& param, const ModelArgs& args, const ParallelArgs& parallel_args); @@ -106,28 +80,6 @@ class Qwen2dot5VisionEncoderLayerImpl : public BaseLayer { int64_t init_node(atb_speed::Model::Node& node, atb_speed::qwen::VisionEncoderLayerParam& param); - void pad_qkv_weights(); - - void pad_mlp_weights(); - - torch::Tensor pad_tensor(const torch::Tensor& tensor, - int64_t target_shape, - int64_t dim = 0) { - int64_t pad_size = target_shape - tensor.size(dim); - if (tensor.dim() == 1) { - return torch::nn::functional::pad( - tensor, torch::nn::functional::PadFuncOptions({0, pad_size})); - } else if (tensor.dim() == 2) { - if (1 == dim) - return torch::nn::functional::pad( - tensor, torch::nn::functional::PadFuncOptions({0, pad_size, 0, 0})); - else - return torch::nn::functional::pad( - tensor, torch::nn::functional::PadFuncOptions({0, 0, 0, pad_size})); - } - return tensor; - } - atb_speed::Model::Node encode_node_; std::string model_name_; diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp index 96e15b2fb..e5374dd1b 100644 --- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp @@ -29,152 +29,8 @@ limitations under the License. 
namespace xllm { namespace layer { -enum DecoderLayerTensorId : int { - IN_NORM_WEIGHT = 0, // weight - IN_NORM_BIAS = 1, // bias - IN_NORM_NEW_WEIGHT = 2, // new weight - IN_NORM_NEW_BIAS = 3, // new bias - - IN_Q_WEIGHT = 4, // weight - IN_Q_BIAS = 5, // bias - IN_Q_DEQSCALE = 6, // deq_scale - IN_Q_OFFSET = 7, // offset - IN_Q_SCALE = 8, // scale - IN_Q_COMPRESS_IDX = 9, - - IN_K_WEIGHT = 10, // weight - IN_K_BIAS = 11, // bias - IN_K_DEQSCALE = 12, // deq_scale - IN_K_OFFSET = 13, // offset - IN_K_SCALE = 14, // scale - IN_K_COMPRESS_IDX = 15, - - IN_V_WEIGHT = 16, // weight - IN_V_BIAS = 17, // bias - IN_V_DEQSCALE = 18, // deq_scale - IN_V_OFFSET = 19, // offset - IN_V_SCALE = 20, // scale - IN_V_COMPRESS_IDX = 21, - - IN_ATTENTION_OUT_WEIGHT = 22, // weight - IN_ATTENTION_OUT_BIAS = 23, // bias - IN_ATTENTION_OUT_DEQSCALE = 24, // deq_scale - IN_ATTENTION_OUT_OFFSET = 25, // offset - IN_ATTENTION_OUT_SCALE = 26, // scale - IN_ATTENTION_OUT_COMPRESS_IDX = 27, - - IN_SELFOUT_NORM_WEIGHT = 28, // weight - IN_SELFOUT_NORM_BIAS = 29, // bias - IN_SELFOUT_NORM_NEW_WEIGHT = 30, // new weight - IN_SELFOUT_NORM_NEW_BIAS = 31, // new bias - - IN_MLP_W2_WEIGHT = 32, // weight - IN_MLP_W2_BIAS = 33, // bias - IN_MLP_W2_DEQSCALE = 34, // deq_scale - IN_MLP_W2_OFFSET = 35, // offset - IN_MLP_W2_SCALE = 36, // scale - IN_MLP_W2_COMPRESS_IDX = 37, - - IN_MLP_W1_WEIGHT = 38, // weight - IN_MLP_W1_BIAS = 39, // bias - IN_MLP_W1_DEQSCALE = 40, // deq_scale - IN_MLP_W1_OFFSET = 41, // offset - IN_MLP_W1_SCALE = 42, // scale - IN_MLP_W1_COMPRESS_IDX = 43, - - IN_MLP_CPROJ_WEIGHT = 44, // weight - IN_MLP_CPROJ_BIAS = 45, // bias - IN_MLP_CPROJ_DEQSCALE = 46, // deq_scale - IN_MLP_CPROJ_OFFSET = 47, // offset - IN_MLP_CPROJ_SCALE = 48, // scale - IN_MLP_CPROJ_COMPRESS_IDX = 49, - - IN_QKV_SCALE_FILL = 50, - IN_QKV_OFFSET_FILL = 51, - IN_MLP_SCALE_FILL = 52, - IN_MLP_OFFSET_FILL = 53, - Q_NORM_WEIGHT = 54, - K_NORM_WEIGHT = 55, -}; - const uint64_t WEIGHT_COUNT_PER_LAYER = 56; -static std::vector> WEIGHT_MAPPING = { - {IN_NORM_WEIGHT, "input_layernorm.weight"}, - {IN_Q_WEIGHT, "self_attn.q_proj.weight"}, - {IN_K_WEIGHT, "self_attn.k_proj.weight"}, - {IN_V_WEIGHT, "self_attn.v_proj.weight"}, - {IN_ATTENTION_OUT_WEIGHT, "self_attn.o_proj.weight"}, - {IN_SELFOUT_NORM_WEIGHT, "post_attention_layernorm.weight"}, - {IN_MLP_W2_WEIGHT, "mlp.gate_proj.weight"}, - {IN_MLP_W1_WEIGHT, "mlp.up_proj.weight"}, - {IN_MLP_CPROJ_WEIGHT, "mlp.down_proj.weight"}, - {Q_NORM_WEIGHT, "self_attn.q_norm.weight"}, - {K_NORM_WEIGHT, "self_attn.k_norm.weight"}}; - -static std::vector> WEIGHT_MAPPING_W8A8 = { - {IN_NORM_WEIGHT, "input_layernorm.weight"}, - {IN_Q_WEIGHT, "self_attn.q_proj.weight"}, - {IN_Q_BIAS, "self_attn.q_proj.quant_bias"}, - {IN_Q_DEQSCALE, "self_attn.q_proj.deq_scale"}, - {IN_Q_OFFSET, "self_attn.q_proj.input_offset"}, - {IN_Q_SCALE, "self_attn.q_proj.input_scale"}, - {IN_K_WEIGHT, "self_attn.k_proj.weight"}, - {IN_K_BIAS, "self_attn.k_proj.quant_bias"}, - {IN_K_DEQSCALE, "self_attn.k_proj.deq_scale"}, - {IN_K_OFFSET, "self_attn.k_proj.input_offset"}, - {IN_K_SCALE, "self_attn.k_proj.input_scale"}, - {IN_V_WEIGHT, "self_attn.v_proj.weight"}, - {IN_V_BIAS, "self_attn.v_proj.quant_bias"}, - {IN_V_DEQSCALE, "self_attn.v_proj.deq_scale"}, - {IN_V_OFFSET, "self_attn.v_proj.input_offset"}, - {IN_V_SCALE, "self_attn.v_proj.input_scale"}, - {IN_ATTENTION_OUT_WEIGHT, "self_attn.o_proj.weight"}, - {IN_ATTENTION_OUT_BIAS, "self_attn.o_proj.quant_bias"}, - {IN_ATTENTION_OUT_DEQSCALE, "self_attn.o_proj.deq_scale"}, - 
{IN_ATTENTION_OUT_OFFSET, "self_attn.o_proj.input_offset"}, - {IN_ATTENTION_OUT_SCALE, "self_attn.o_proj.input_scale"}, - {IN_SELFOUT_NORM_WEIGHT, "post_attention_layernorm.weight"}, - {IN_MLP_W2_WEIGHT, "mlp.gate_proj.weight"}, - {IN_MLP_W2_BIAS, "mlp.gate_proj.quant_bias"}, - {IN_MLP_W2_DEQSCALE, "mlp.gate_proj.deq_scale"}, - {IN_MLP_W2_OFFSET, "mlp.gate_proj.input_offset"}, - {IN_MLP_W2_SCALE, "mlp.gate_proj.input_scale"}, - {IN_MLP_W1_WEIGHT, "mlp.up_proj.weight"}, - {IN_MLP_W1_BIAS, "mlp.up_proj.quant_bias"}, - {IN_MLP_W1_DEQSCALE, "mlp.up_proj.deq_scale"}, - {IN_MLP_W1_OFFSET, "mlp.up_proj.input_offset"}, - {IN_MLP_W1_SCALE, "mlp.up_proj.input_scale"}, - {IN_MLP_CPROJ_WEIGHT, "mlp.down_proj.weight"}, - {Q_NORM_WEIGHT, "self_attn.q_norm.weight"}, - {K_NORM_WEIGHT, "self_attn.k_norm.weight"}}; - -static std::map WEIGHT_SHARD = {{IN_Q_WEIGHT, 0}, - {IN_K_WEIGHT, 0}, - {IN_V_WEIGHT, 0}, - {IN_ATTENTION_OUT_WEIGHT, 1}, - {IN_MLP_W2_WEIGHT, 0}, - {IN_MLP_W1_WEIGHT, 0}, - {IN_MLP_CPROJ_WEIGHT, 1}}; - -static std::map WEIGHT_SHARD_W8A8 = {{IN_Q_WEIGHT, 0}, - {IN_Q_BIAS, 0}, - {IN_Q_DEQSCALE, 0}, - {IN_K_WEIGHT, 0}, - {IN_K_BIAS, 0}, - {IN_K_DEQSCALE, 0}, - {IN_V_WEIGHT, 0}, - {IN_V_BIAS, 0}, - {IN_V_DEQSCALE, 0}, - {IN_ATTENTION_OUT_WEIGHT, 1}, - {IN_MLP_W2_WEIGHT, 0}, - {IN_MLP_W2_BIAS, 0}, - {IN_MLP_W2_DEQSCALE, 0}, - {IN_MLP_W1_WEIGHT, 0}, - {IN_MLP_W1_BIAS, 0}, - {IN_MLP_W1_DEQSCALE, 0}, - {IN_MLP_CPROJ_WEIGHT, 1}}; - void Qwen3DecoderLayerImpl::param_from_args( atb_speed::qwen::QwenLayerParam& param, const ModelArgs& args, @@ -280,11 +136,9 @@ Qwen3DecoderLayerImpl::Qwen3DecoderLayerImpl(const ModelContext& context) param_from_args(prefill_param_, model_args, parallel_args, true); param_from_args(decode_param_, model_args, parallel_args, false); - at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); placeholder_vec_ = {1}; dtype_ = c10::typeMetaToScalarType(options.dtype()); - rank_id_ = parallel_args.rank(); prefill_tensor_storage_.resize(4); decode_tensor_storage_.resize(4); prefill_vector_storage_.resize(1); @@ -292,173 +146,26 @@ Qwen3DecoderLayerImpl::Qwen3DecoderLayerImpl(const ModelContext& context) placeholder_ = atb_speed::Utils::AtTensor2Tensor( torch::zeros({1}).to(device_).to(dtype_)); at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_); - for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { - at_weight_tensors_[i] = torch::zeros({1}).to(options); - } -} - -void Qwen3DecoderLayerImpl::verify_loaded_weights() const { - for (const auto& [index, name] : WEIGHT_MAPPING) { - CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) - << "weight is not loaded for " << name; - } + loader_ = std::make_unique( + WEIGHT_COUNT_PER_LAYER, + context, + prefill_param_.enableIntraLayerAddNorm || + prefill_param_.enableInterLayerAddNorm); } void Qwen3DecoderLayerImpl::merge_loaded_weights() { - if (quantize_type_.compare("w8a8") == 0) { - at_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE] = - at_weight_tensors_[IN_ATTENTION_OUT_DEQSCALE].to(torch::kFloat32); - at_weight_tensors_[IN_Q_DEQSCALE] = - torch::cat({at_weight_tensors_[IN_Q_DEQSCALE], - at_weight_tensors_[IN_K_DEQSCALE], - at_weight_tensors_[IN_V_DEQSCALE]}, - 0) - .to(torch::kFloat32); - - at_weight_tensors_[IN_Q_BIAS] = torch::cat({at_weight_tensors_[IN_Q_BIAS], - at_weight_tensors_[IN_K_BIAS], - at_weight_tensors_[IN_V_BIAS]}, - 0) - .to(torch::kInt32); - - for (auto idx : {IN_K_DEQSCALE, - IN_V_DEQSCALE, - IN_K_BIAS, - IN_V_BIAS, - IN_K_OFFSET, - IN_V_OFFSET, - IN_K_SCALE, - 
IN_V_SCALE}) { - at_weight_tensors_[idx] = at_placeholder_; - } - - at_weight_tensors_[IN_MLP_W2_BIAS] = - torch::cat({at_weight_tensors_[IN_MLP_W2_BIAS], - at_weight_tensors_[IN_MLP_W1_BIAS]}, - 0); - - at_weight_tensors_[IN_MLP_W2_DEQSCALE] = - torch::cat({at_weight_tensors_[IN_MLP_W2_DEQSCALE], - at_weight_tensors_[IN_MLP_W1_DEQSCALE]}, - 0) - .to(torch::kFloat32); - - for (auto idx : {IN_MLP_W1_BIAS, - IN_MLP_W1_OFFSET, - IN_MLP_W1_SCALE, - IN_MLP_W1_DEQSCALE}) { - at_weight_tensors_[idx] = at_placeholder_; - } - - at_weight_tensors_[IN_Q_OFFSET] = - at_weight_tensors_[IN_Q_OFFSET].to(torch::kInt8).to(device_); - at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] = - at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] - .to(torch::kInt8) - .to(device_); - at_weight_tensors_[IN_MLP_W2_OFFSET] = - at_weight_tensors_[IN_MLP_W2_OFFSET].to(torch::kInt8).to(device_); - - if (rank_id_ != 0) { - torch::Tensor original_tensor = at_weight_tensors_[IN_ATTENTION_OUT_BIAS]; - auto shape = original_tensor.sizes(); - auto dtype = original_tensor.dtype(); - auto device = original_tensor.device(); - - at_weight_tensors_[IN_ATTENTION_OUT_BIAS] = torch::zeros( - shape, torch::TensorOptions().dtype(dtype).device(device)); - } - } - - at_weight_tensors_[IN_Q_WEIGHT] = - torch::cat({at_weight_tensors_[IN_Q_WEIGHT], - at_weight_tensors_[IN_K_WEIGHT], - at_weight_tensors_[IN_V_WEIGHT]}, - 0) - .contiguous(); - - at_weight_tensors_[IN_MLP_W2_WEIGHT] = - torch::cat({at_weight_tensors_[IN_MLP_W2_WEIGHT], - at_weight_tensors_[IN_MLP_W1_WEIGHT]}, - 0) - .contiguous(); - - for (auto idx : - {IN_MLP_W1_WEIGHT, IN_K_WEIGHT, IN_V_WEIGHT, IN_K_BIAS, IN_V_BIAS}) { - at_weight_tensors_[idx] = at_placeholder_; - } - - if (prefill_param_.enableIntraLayerAddNorm || - prefill_param_.enableInterLayerAddNorm) { - if (quantize_type_.compare("w8a8") == 0) { - // quantize - torch::ScalarType weight_fill_dtype = torch::kBFloat16; - int64_t weight_attn_shape = at_weight_tensors_[IN_Q_WEIGHT].size(-1); - int64_t weight_mlp_shape = at_weight_tensors_[IN_MLP_W2_WEIGHT].size(-1); - at_weight_tensors_[IN_QKV_SCALE_FILL] = at_weight_tensors_[IN_Q_SCALE] - .repeat(weight_attn_shape) - .to(weight_fill_dtype); - at_weight_tensors_[IN_MLP_SCALE_FILL] = - at_weight_tensors_[IN_MLP_W2_SCALE] - .repeat(weight_mlp_shape) - .to(weight_fill_dtype); - at_weight_tensors_[IN_QKV_OFFSET_FILL] = at_weight_tensors_[IN_Q_OFFSET] - .repeat(weight_attn_shape) - .to(weight_fill_dtype); - at_weight_tensors_[IN_MLP_OFFSET_FILL] = - at_weight_tensors_[IN_MLP_W2_OFFSET] - .repeat(weight_mlp_shape) - .to(weight_fill_dtype); - } else { - // bfloat16 or float16 - for (auto idx : {IN_QKV_SCALE_FILL, - IN_QKV_OFFSET_FILL, - IN_MLP_SCALE_FILL, - IN_MLP_OFFSET_FILL}) { - at_weight_tensors_[idx] = at_placeholder_; - } - } - } - + loader_->merge_loaded_weights(); + auto& at_weight_tensors = loader_->get_at_weight_tensors(); c10_npu::NPUCachingAllocator::emptyCache(); for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { + LOG(INFO) << "device: " << at_weight_tensors[i].device(); atb_weight_tensors_[i] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[i]); } init_layer(); } -void Qwen3DecoderLayerImpl::load_state_dict(const StateDict& state_dict) { - if (quantize_type_.compare("w8a8") == 0) { - for (const auto& [index, name] : WEIGHT_MAPPING_W8A8) { - if (WEIGHT_SHARD_W8A8.find(index) != WEIGHT_SHARD_W8A8.end()) { - set_weight(state_dict, name, index, WEIGHT_SHARD_W8A8[index]); - } else { - set_weight(state_dict, name, 
index); - } - } - at_weight_tensors_[IN_NORM_BIAS] = - torch::zeros(at_weight_tensors_[IN_NORM_WEIGHT].sizes(), - at_weight_tensors_[IN_NORM_WEIGHT].options()) - .to(device_); - - at_weight_tensors_[IN_SELFOUT_NORM_BIAS] = - torch::zeros(at_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].sizes(), - at_weight_tensors_[IN_SELFOUT_NORM_WEIGHT].options()) - .to(device_); - return; - } - - for (const auto& [index, name] : WEIGHT_MAPPING) { - if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { - set_weight(state_dict, name, index, WEIGHT_SHARD[index]); - } else { - set_weight(state_dict, name, index); - } - } -} - int64_t Qwen3DecoderLayerImpl::init_layer() { init_attn_mask(); name_ = "qwen3_decoder_layer"; diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h index c4e12fa06..d8c01c436 100644 --- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h @@ -31,6 +31,7 @@ limitations under the License. #include "framework/model/model_input_params.h" #include "framework/model_context.h" #include "framework/state_dict/state_dict.h" +#include "loader/qwen3_decoder_loader.h" #include "nlohmann/json.hpp" #include "npu_base_layer.h" #include "pytorch/adapter/utils/utils.h" @@ -39,7 +40,6 @@ limitations under the License. #include "xllm_kernels/core/include/atb_speed/log.h" #include "xllm_kernels/core/include/atb_speed/utils/model_factory.h" #include "xllm_kernels/models/qwen3/layer/decoder_layer.h" - namespace xllm { namespace layer { @@ -49,9 +49,9 @@ class Qwen3DecoderLayerImpl : public BaseLayer { ~Qwen3DecoderLayerImpl() {}; - virtual void load_state_dict(const StateDict& state_dict) override; + // virtual void load_state_dict(const StateDict& state_dict) override; - virtual void verify_loaded_weights() const override; + // virtual void verify_loaded_weights() const override; virtual void merge_loaded_weights() override; diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp index f1d7b2868..919209541 100644 --- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp @@ -24,189 +24,8 @@ limitations under the License. 
namespace xllm { namespace layer { -enum DecoderLayerTensorId : int { - IN_INPUT_NORM_WEIGHT = 0, // [2048] - IN_INPUT_NORM_BIAS = 1, - IN_INPUT_NORM_NEW_WEIGHT = 2, - IN_INPUT_NORM_NEW_BIAS = 3, - - IN_QKV_WEIGHT_0 = 4, // [4096, 2048] - IN_QKV_BIAS_0 = 5, - IN_QKV_DESCALE_0 = 6, - IN_QKV_OFFSET_0 = 7, - IN_QKV_SCALE_0 = 8, - IN_QKV_COMPRESS_IDX_0 = 9, - - IN_QKV_WEIGHT_1 = 10, // [512, 2048] - IN_QKV_BIAS_1 = 11, - IN_QKV_DESCALE_1 = 12, - IN_QKV_OFFSET_1 = 13, - IN_QKV_SCALE_1 = 14, - IN_QKV_COMPRESS_IDX_1 = 15, - - IN_QKV_WEIGHT_2 = 16, // [512, 2048] - IN_QKV_BIAS_2 = 17, - IN_QKV_DESCALE_2 = 18, - IN_QKV_OFFSET_2 = 19, - IN_QKV_SCALE_2 = 20, - IN_QKV_COMPRESS_IDX_2 = 21, - - IN_ATTENTION_OUT_WEIGHT = 22, // [2048, 4096] - IN_ATTENTION_OUT_BIAS = 23, - IN_ATTENTION_OUT_DESCALE = 24, - IN_ATTENTION_OUT_OFFSET = 25, - IN_ATTENTION_OUT_SCALE = 26, - IN_ATTENTION_OUT_COMPRESS_IDX = 27, - - IN_Q_NORM_WEIGHT = 28, // [128] - IN_K_NORM_WEIGHT = 29, // [128] - - IN_SELFATTENTION_OUT_NORM_WEIGHT = 30, // [2048] - IN_SELFATTENTION_OUT_NORM_BIAS = 31, - IN_SELFATTENTION_OUT_NEW_NORM_WEIGHT = 32, - IN_SELFATTENTION_OUT_NEW_NORM_BIAS = 33, - - IN_BLOCK_SPARSE_MOE_GATE_WEIGHT = 34, // [128, 2048] - IN_BLOCK_SPARSE_MOE_GATE_BIAS = 35, - IN_BLOCK_SPARSE_MOE_GATE_DESCALE = 36, - IN_BLOCK_SPARSE_MOE_GATE_OFFSET = 37, - IN_BLOCK_SPARSE_MOE_GATE_SCALE = 38, - IN_BLOCK_SPARSE_MOE_GATE_COMPRESS_IDX = 39, - - IN_MLP_GATEUP_WEIGHT_EXPERT = 40, - IN_MLP_GATEUP_BIAS_EXPERT = 41, - IN_MLP_GATEUP_DESCALE_EXPERT = 42, - IN_MLP_GATEUP_OFFSET_EXPERT = 43, - IN_MLP_GATEUP_SCALE_EXPERT = 44, - IN_MLP_GATEUP_COMPRESS_IDX_EXPERT = 45, - - IN_MLP_DOWN_WEIGHT_EXPERT = 46, // [2048, 768] - IN_MLP_DOWN_BIAS_EXPERT = 47, - IN_MLP_DOWN_DESCALE_EXPERT = 48, - IN_MLP_DOWN_OFFSET_EXPERT = 49, - IN_MLP_DOWN_SCALE_EXPERT = 50, - IN_MLP_DOWN_COMPRESS_IDX_EXPERT = 51, - - IN_MLP_SHARED_GATEUP_WEIGHT = 52, - IN_MLP_SHARED_DOWN_WEIGHT = 53, - IN_MLP_SHARED_EXPERT_GATE = 54, -}; - static const uint64_t WEIGHT_COUNT_PER_LAYER = 55; -static const std::unordered_map WEIGHT_MAPPING = { - {"input_layernorm.weight", IN_INPUT_NORM_WEIGHT}, - - {"self_attn.q_proj.weight", IN_QKV_WEIGHT_0}, - - {"self_attn.k_proj.weight", IN_QKV_WEIGHT_1}, - - {"self_attn.v_proj.weight", IN_QKV_WEIGHT_2}, - - {"self_attn.o_proj.weight", IN_ATTENTION_OUT_WEIGHT}, - - {"self_attn.q_norm.weight", IN_Q_NORM_WEIGHT}, - {"self_attn.k_norm.weight", IN_K_NORM_WEIGHT}, - - {"post_attention_layernorm.weight", IN_SELFATTENTION_OUT_NORM_WEIGHT}, - - // MoE Gate - {"mlp.gate.weight", IN_BLOCK_SPARSE_MOE_GATE_WEIGHT}, - - // Expert MLP - Gate/Up projections - {"gate_proj.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, - - {"up_proj.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, - - // Expert MLP - Down projection - {"down_proj.weight", IN_MLP_DOWN_WEIGHT_EXPERT}, - -}; - -static const std::unordered_map WEIGHT_MAPPING_W8A8 = { - {"input_layernorm.weight", IN_INPUT_NORM_WEIGHT}, - {"input_layernorm.bias", IN_INPUT_NORM_NEW_BIAS}, - - {"self_attn.q_proj.weight", IN_QKV_WEIGHT_0}, - {"self_attn.q_proj.bias", IN_QKV_BIAS_0}, - {"self_attn.q_proj.deq_scale", IN_QKV_DESCALE_0}, - {"self_attn.q_proj.weight_offset", IN_QKV_OFFSET_0}, - {"self_attn.q_proj.weight_scale", IN_QKV_SCALE_0}, - - {"self_attn.k_proj.weight", IN_QKV_WEIGHT_1}, - {"self_attn.k_proj.bias", IN_QKV_BIAS_1}, - {"self_attn.k_proj.deq_scale", IN_QKV_DESCALE_1}, - {"self_attn.k_proj.weight_offset", IN_QKV_OFFSET_1}, - {"self_attn.k_proj.weight_scale", IN_QKV_SCALE_1}, - - {"self_attn.v_proj.weight", IN_QKV_WEIGHT_2}, - 
{"self_attn.v_proj.bias", IN_QKV_BIAS_2}, - {"self_attn.v_proj.deq_scale", IN_QKV_DESCALE_2}, - {"self_attn.v_proj.weight_offset", IN_QKV_OFFSET_2}, - {"self_attn.v_proj.weight_scale", IN_QKV_SCALE_2}, - - {"self_attn.o_proj.weight", IN_ATTENTION_OUT_WEIGHT}, - {"self_attn.o_proj.quant_bias", IN_ATTENTION_OUT_BIAS}, - {"self_attn.o_proj.deq_scale", IN_ATTENTION_OUT_DESCALE}, - {"self_attn.o_proj.weight_offset", IN_ATTENTION_OUT_OFFSET}, - {"self_attn.o_proj.weight_scale", IN_ATTENTION_OUT_SCALE}, - - {"self_attn.q_norm.weight", IN_Q_NORM_WEIGHT}, - {"self_attn.k_norm.weight", IN_K_NORM_WEIGHT}, - - {"post_attention_layernorm.weight", IN_SELFATTENTION_OUT_NORM_WEIGHT}, - {"post_attention_layernorm.bias", IN_SELFATTENTION_OUT_NEW_NORM_BIAS}, - - // MoE Gate - {"mlp.gate.weight", IN_BLOCK_SPARSE_MOE_GATE_WEIGHT}, - - {"gate_proj.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, - {"gate_proj.weight_offset", IN_MLP_GATEUP_OFFSET_EXPERT}, - {"gate_proj.weight_scale", IN_MLP_GATEUP_SCALE_EXPERT}, - {"up_proj.weight", IN_MLP_GATEUP_WEIGHT_EXPERT}, - {"up_proj.weight_offset", IN_MLP_GATEUP_OFFSET_EXPERT}, - {"up_proj.weight_scale", IN_MLP_GATEUP_SCALE_EXPERT}, - - {"down_proj.weight", IN_MLP_DOWN_WEIGHT_EXPERT}, - {"down_proj.weight_offset", IN_MLP_DOWN_OFFSET_EXPERT}, - {"down_proj.weight_scale", IN_MLP_DOWN_SCALE_EXPERT}, -}; - -static const std::unordered_map> - SPECIAL_MULTI_ASSIGN_W8A8 = { - {"input_layernorm.weight", - {IN_INPUT_NORM_WEIGHT, IN_INPUT_NORM_NEW_WEIGHT}}, - {"post_attention_layernorm.weight", - {IN_SELFATTENTION_OUT_NORM_WEIGHT, - IN_SELFATTENTION_OUT_NEW_NORM_WEIGHT}}, -}; - -static const std::map WEIGHT_SHARD = { - {IN_QKV_WEIGHT_0, 0}, - {IN_QKV_WEIGHT_1, 0}, - {IN_QKV_WEIGHT_2, 0}, - {IN_ATTENTION_OUT_WEIGHT, 1}, - {IN_MLP_GATEUP_WEIGHT_EXPERT, 0}, - {IN_MLP_DOWN_WEIGHT_EXPERT, 1}, -}; - -static const std::map WEIGHT_SHARD_W8A8 = { - {IN_QKV_WEIGHT_0, 0}, - {IN_QKV_OFFSET_0, 0}, - {IN_QKV_SCALE_0, 0}, - {IN_QKV_WEIGHT_1, 0}, - {IN_QKV_OFFSET_1, 0}, - {IN_QKV_SCALE_1, 0}, - {IN_QKV_WEIGHT_2, 0}, - {IN_QKV_OFFSET_2, 0}, - {IN_QKV_SCALE_2, 0}, - {IN_ATTENTION_OUT_WEIGHT, 1}, - {IN_MLP_GATEUP_WEIGHT_EXPERT, 0}, - {IN_MLP_GATEUP_OFFSET_EXPERT, 0}, - {IN_MLP_GATEUP_SCALE_EXPERT, 0}, - {IN_MLP_DOWN_WEIGHT_EXPERT, 1}, -}; - Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context, const int32_t layer_id) : BaseLayer(context), @@ -234,17 +53,16 @@ Qwen3MoeDecoderLayerImpl::Qwen3MoeDecoderLayerImpl(const ModelContext& context, CHECK_EQ(parallel_args.world_size(), dp_size_ * dp_local_tp_size_); dp_local_tp_rank_ = parallel_args.rank() % dp_local_tp_size_; - n_kv_heads_ = static_cast(model_args.n_kv_heads().value()); - param_from_args(prefill_param_, model_args, parallel_args, true); param_from_args(decode_param_, model_args, parallel_args, false); + loader_ = + std::make_unique(WEIGHT_COUNT_PER_LAYER, context); initialize_tensors(options); } void Qwen3MoeDecoderLayerImpl::initialize_tensors( const torch::TensorOptions& options) { // initializ placeholder - at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); placeholder_vec_ = {1}; int_tensor_placeholder_ = torch::ones({1}).to(torch::kInt32).to(device_); @@ -252,16 +70,10 @@ void Qwen3MoeDecoderLayerImpl::initialize_tensors( block_tables_placeholder_ = torch::zeros({1, 1}).to(torch::kInt32).to(device_); tensor_placeholder_ = torch::zeros({1}).to(options); - resize_experts_weights(num_experts_per_partition_); + loader_->resize_experts_weights(num_experts_per_partition_); 
one_hot_ = torch::tensor({1}, torch::kInt32).to(device_); zero_hot_ = torch::tensor({0}, torch::kInt32).to(device_); - at_start_expert_id_ = - torch::tensor({start_expert_id_}, torch::kInt64).to(device_); - at_in_device_expert_count_ = - torch::tensor({num_experts_per_partition_ - 1}, torch::kInt64) - .to(device_); expert_group_ = torch::tensor({1}, torch::dtype(torch::kInt32)).to(device_); - initialize_weight_tensors(options); } void Qwen3MoeDecoderLayerImpl::param_from_args( @@ -276,37 +88,6 @@ void Qwen3MoeDecoderLayerImpl::param_from_args( initialize_quantization_parameters(param); } -void Qwen3MoeDecoderLayerImpl::resize_experts_weights( - int num_of_device_experts) { - experts_weights_["gate_proj.weight"] = - std::vector(num_of_device_experts); - experts_weights_["up_proj.weight"] = - std::vector(num_of_device_experts); - experts_weights_["down_proj.weight"] = - std::vector(num_of_device_experts); - if (quantize_type_.compare("w8a8_dynamic") == 0) { - experts_weights_["gate_proj.weight_offset"] = - std::vector(num_of_device_experts); - experts_weights_["up_proj.weight_offset"] = - std::vector(num_of_device_experts); - experts_weights_["down_proj.weight_offset"] = - std::vector(num_of_device_experts); - experts_weights_["gate_proj.weight_scale"] = - std::vector(num_of_device_experts); - experts_weights_["up_proj.weight_scale"] = - std::vector(num_of_device_experts); - experts_weights_["down_proj.weight_scale"] = - std::vector(num_of_device_experts); - } -} - -void Qwen3MoeDecoderLayerImpl::initialize_weight_tensors( - const torch::TensorOptions& options) { - for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { - at_weight_tensors_[i] = torch::zeros({1}).to(options); - } -} - void Qwen3MoeDecoderLayerImpl::initialize_basic_parameters( atb_speed::qwen::MoeDecoderLayerParam& param, const ModelArgs& args, @@ -447,402 +228,17 @@ void Qwen3MoeDecoderLayerImpl::initialize_quantization_parameters( } } -void Qwen3MoeDecoderLayerImpl::load_state_dict(const StateDict& state_dict) { - for (const auto& [name, tensor] : state_dict) { - bool is_sharded = false; - int index = 0; - - if (absl::StartsWith(name, "mlp.experts")) { - process_expert_weights(state_dict, name, tensor); - continue; - } - - if (absl::StartsWith(name, "mlp") && !absl::StrContains(name, "gate.")) { - process_mlp_common_weights(state_dict, name, tensor); - continue; - } - - process_general_weights(state_dict, name, tensor); - } -} - -int Qwen3MoeDecoderLayerImpl::get_mapped_index( - const std::string& name, - const std::unordered_map& mapping) { - const auto it = mapping.find(name); - if (it == mapping.end()) { - LOG(ERROR) << "Missing mapping for: " << name; - return -1; - } - - return it->second; -} - -void Qwen3MoeDecoderLayerImpl::process_expert_weights( - const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor) { - int expert_index = extract_expert_index(name); - if (expert_index < start_expert_id_ || expert_index > end_expert_id_) { - return; - } - - const std::string suffix = extract_endswith(name); - const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_MAPPING_W8A8 - : WEIGHT_MAPPING; - const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_SHARD_W8A8 - : WEIGHT_SHARD; - const int index = get_mapped_index(suffix, weight_mapping); - const int local_index = expert_index % num_experts_per_partition_; - const bool is_sharded = shard_map.count(index); - - torch::Tensor tmp_tensor = is_sharded - ? 
get_sharded_tensor(state_dict, - name, - shard_map.at(index), - ep_local_tp_rank_, - ep_local_tp_size_) - : tensor; - - experts_weights_[suffix][local_index] = tmp_tensor.clone(); -} - -void Qwen3MoeDecoderLayerImpl::process_mlp_common_weights( - const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor) { - const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_MAPPING_W8A8 - : WEIGHT_MAPPING; - const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_SHARD_W8A8 - : WEIGHT_SHARD; - const int index = get_mapped_index(name, weight_mapping); - const bool is_sharded = shard_map.count(index); - - torch::Tensor tmp_tensor = is_sharded - ? get_sharded_tensor(state_dict, - name, - shard_map.at(index), - dp_local_tp_rank_, - dp_local_tp_size_) - .to(device_) - : tensor.to(device_); - if (absl::StrContains(name, "down_proj")) { - at_weight_tensors_[index] = tmp_tensor; - } else { - shared_experts_weights_[name] = tmp_tensor; - } -} - -void Qwen3MoeDecoderLayerImpl::process_general_weights( - const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor) { - const auto& weight_mapping = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_MAPPING_W8A8 - : WEIGHT_MAPPING; - const auto& shard_map = (quantize_type_.compare("w8a8_dynamic") == 0) - ? WEIGHT_SHARD_W8A8 - : WEIGHT_SHARD; - - if (weight_mapping.find(name) == weight_mapping.end()) { - return; - } - - const int index = get_mapped_index(name, weight_mapping); - const bool is_sharded = shard_map.count(index); - torch::Tensor tmp_tensor; - int32_t tp_rank = dp_local_tp_rank_; - int32_t tp_size = dp_local_tp_size_; - - static const std::unordered_set qkv_tensor_indices = {IN_QKV_WEIGHT_1, - IN_QKV_WEIGHT_2, - IN_QKV_BIAS_1, - IN_QKV_BIAS_2, - IN_QKV_DESCALE_1, - IN_QKV_DESCALE_2, - IN_QKV_OFFSET_1, - IN_QKV_OFFSET_2, - IN_QKV_SCALE_1, - IN_QKV_SCALE_2}; - - if (qkv_tensor_indices.count(index) > 0) { - if (n_kv_heads_ < dp_local_tp_size_) { - int32_t repeat_times = (dp_local_tp_size_ / n_kv_heads_); - - tp_rank = tp_rank / repeat_times; - tp_size = n_kv_heads_; - } - } - if (is_sharded) { - tmp_tensor = get_sharded_tensor( - state_dict, name, shard_map.at(index), tp_rank, tp_size) - .to(device_); - } else { - tmp_tensor = tensor.to(device_); - } - - correct_tensor_dtype(tmp_tensor, name); - if (quantize_type_.compare("w8a8_dynamic") == 0) { - auto it = SPECIAL_MULTI_ASSIGN_W8A8.find(name); - if (it != SPECIAL_MULTI_ASSIGN_W8A8.end()) { - for (int idx : it->second) { - at_weight_tensors_[idx] = tmp_tensor; - } - return; - } - } - at_weight_tensors_[index] = tmp_tensor; -} - -torch::Tensor Qwen3MoeDecoderLayerImpl::get_sharded_tensor( - const StateDict& state_dict, - const std::string& name, - int dim) { - if (parallel_args_.world_size() > 1) { - return state_dict.get_sharded_tensor( - name, dim, parallel_args_.rank(), parallel_args_.world_size()); - } else { - return state_dict.get_tensor(name); - } -} - -torch::Tensor Qwen3MoeDecoderLayerImpl::get_sharded_tensor( - const StateDict& state_dict, - const std::string& name, - int dim, - int loacal_tp_rank, - int local_tp_size) { - if (local_tp_size > 1) { - return state_dict.get_sharded_tensor( - name, dim, loacal_tp_rank, local_tp_size); - } else { - return state_dict.get_tensor(name); - } -} - -std::string Qwen3MoeDecoderLayerImpl::extract_endswith( - const std::string& input) { - std::vector parts; - std::stringstream ss(input); - std::string part; - while (std::getline(ss, part, '.')) { - 
parts.push_back(part); - } - if (parts.size() < 2) { - return ""; - } - std::string result = parts[parts.size() - 2] + "." + parts[parts.size() - 1]; - - return result; -} - -int Qwen3MoeDecoderLayerImpl::extract_expert_index(const std::string& name) { - std::string prefix = "experts."; - size_t pos = name.find(prefix); - if (pos != std::string::npos) { - pos += prefix.length(); - size_t end_pos = pos; - while (end_pos < name.length() && std::isdigit(name[end_pos])) { - ++end_pos; - } - if (end_pos > pos) { - return std::stoi(name.substr(pos, end_pos - pos)); - } - } - - return -1; -} - -void Qwen3MoeDecoderLayerImpl::verify_loaded_weights( - const std::string& prefix) const { - for (const auto& [name, index] : WEIGHT_MAPPING) { - if (name == "down_proj.weight" || name == "gate_proj.weight" || - name == "up_proj.weight") { - continue; - } - CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) - << "weight is not loaded for " << name; - } -} - void Qwen3MoeDecoderLayerImpl::merge_loaded_weights() { - merge_experts_weights(); - at_weight_tensors_[IN_QKV_WEIGHT_0] = - torch::cat({at_weight_tensors_[IN_QKV_WEIGHT_0], - at_weight_tensors_[IN_QKV_WEIGHT_1], - at_weight_tensors_[IN_QKV_WEIGHT_2]}, - 0) - .contiguous(); - at_weight_tensors_[IN_QKV_WEIGHT_1] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_WEIGHT_2] = - torch::zeros({1}, torch::kFloat16).to(device_); - - if (quantize_type_.compare("w8a8_dynamic") == 0) { - at_weight_tensors_[IN_QKV_BIAS_0] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_BIAS_1] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_BIAS_2] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_ATTENTION_OUT_BIAS] = - torch::zeros({1}, torch::kFloat16).to(device_); - - at_weight_tensors_[IN_QKV_DESCALE_0] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_DESCALE_1] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_DESCALE_2] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_ATTENTION_OUT_DESCALE] = - torch::zeros({1}, torch::kFloat16).to(device_); - - at_weight_tensors_[IN_QKV_OFFSET_0] = - torch::cat({at_weight_tensors_[IN_QKV_OFFSET_0], - at_weight_tensors_[IN_QKV_OFFSET_1], - at_weight_tensors_[IN_QKV_OFFSET_2]}, - 0) - .contiguous() - .view(-1); - at_weight_tensors_[IN_QKV_OFFSET_1] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_OFFSET_2] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_ATTENTION_OUT_OFFSET] = - at_weight_tensors_[IN_ATTENTION_OUT_OFFSET].contiguous().view(-1); - - at_weight_tensors_[IN_QKV_SCALE_0] = - torch::cat({at_weight_tensors_[IN_QKV_SCALE_0], - at_weight_tensors_[IN_QKV_SCALE_1], - at_weight_tensors_[IN_QKV_SCALE_2]}, - 0) - .contiguous() - .view(-1); - at_weight_tensors_[IN_QKV_SCALE_1] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_QKV_SCALE_2] = - torch::zeros({1}, torch::kFloat16).to(device_); - at_weight_tensors_[IN_ATTENTION_OUT_SCALE] = - at_weight_tensors_[IN_ATTENTION_OUT_SCALE].contiguous().view(-1); - } - + loader_->merge_loaded_weights(); + auto& at_weight_tensors = loader_->get_at_weight_tensors(); c10_npu::NPUCachingAllocator::emptyCache(); for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { atb_weight_tensors_[i] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[i]); } 
init_layer(); } -torch::Tensor Qwen3MoeDecoderLayerImpl::convert_fp16_to_int64( - const torch::Tensor& fp16_tensor) { - auto float_tensor = fp16_tensor.to(torch::kFloat32); - auto int32_tensor = float_tensor.view(torch::kInt32); - auto int64_tensor = int32_tensor.to(torch::kInt64); - return int64_tensor; -} - -void Qwen3MoeDecoderLayerImpl::convert_descaled_weights_to_float() { - auto convert_to_float = [this](int index) { - at_weight_tensors_[index] = at_weight_tensors_[index].to(torch::kFloat32); - }; - convert_to_float(IN_ATTENTION_OUT_DESCALE); -} - -void Qwen3MoeDecoderLayerImpl::merge_experts_weights() { - if (experts_weights_.count("gate_proj.weight") > 0) { - auto& gate_weight = experts_weights_["gate_proj.weight"]; - } - - if (experts_weights_.count("up_proj.weight") > 0) { - auto& up_weight = experts_weights_["up_proj.weight"]; - } - - try { - torch::Tensor mlp_gateup_weight; - if (quantize_type_.compare("w8a8_dynamic") == 0) { - mlp_gateup_weight = - merge_experts_weights(experts_weights_["gate_proj.weight"], - experts_weights_["up_proj.weight"], - /*transpose=*/true); - at_weight_tensors_[IN_MLP_GATEUP_OFFSET_EXPERT] = - merge_experts_weights(experts_weights_["gate_proj.weight_offset"], - experts_weights_["up_proj.weight_offset"]); - at_weight_tensors_[IN_MLP_GATEUP_SCALE_EXPERT] = - merge_experts_weights(experts_weights_["gate_proj.weight_scale"], - experts_weights_["up_proj.weight_scale"]); - } else { - mlp_gateup_weight = - merge_experts_weights(experts_weights_["gate_proj.weight"], - experts_weights_["up_proj.weight"], - /*transpose=*/false); - } - at_weight_tensors_[IN_MLP_GATEUP_WEIGHT_EXPERT] = - at_npu::native::npu_format_cast(mlp_gateup_weight, 2).contiguous(); - } catch (const std::exception& e) { - LOG(ERROR) << "[ERROR] Exception in gateup weight processing: " << e.what(); - throw; - } - - if (experts_weights_.count("down_proj.weight") > 0) { - auto& down_weight = experts_weights_["down_proj.weight"]; - } - - try { - torch::Tensor mlp_down_weight = - merge_experts_weights(experts_weights_["down_proj.weight"], - /*transpose=*/false); - - at_weight_tensors_[IN_MLP_DOWN_WEIGHT_EXPERT] = - at_npu::native::npu_format_cast(mlp_down_weight, 2).contiguous(); - - if (quantize_type_.compare("w8a8_dynamic") == 0) { - at_weight_tensors_[IN_MLP_DOWN_OFFSET_EXPERT] = - merge_experts_weights(experts_weights_["down_proj.weight_offset"]); - at_weight_tensors_[IN_MLP_DOWN_SCALE_EXPERT] = - merge_experts_weights(experts_weights_["down_proj.weight_scale"]); - } - } catch (const std::exception& e) { - LOG(ERROR) << "[ERROR] Exception in down weight processing: " << e.what(); - throw; - } -} - -torch::Tensor Qwen3MoeDecoderLayerImpl::merge_experts_weights( - std::vector& experts, - bool transpose) { - torch::Tensor merged_tensor = torch::stack(experts, 0).to(device_); - if (transpose) { - merged_tensor = merged_tensor.transpose(1, 2); - } - merged_tensor = merged_tensor.contiguous(); - experts.clear(); - - return merged_tensor; -} - -torch::Tensor Qwen3MoeDecoderLayerImpl::merge_experts_weights( - std::vector& experts_gate, - std::vector& experts_up, - bool transpose) { - for (size_t i = 0; i < experts_up.size(); ++i) { - experts_gate[i] = torch::cat({experts_gate[i], experts_up[i]}, 0); - } - torch::Tensor merged_tensor = torch::stack(experts_gate, 0).to(device_); - if (transpose) { - merged_tensor = merged_tensor.transpose(1, 2); - } - merged_tensor = merged_tensor.contiguous(); - experts_gate.clear(); - experts_up.clear(); - - return merged_tensor; -} - int64_t 
Qwen3MoeDecoderLayerImpl::init_layer() { name_ = "qwen3_moe_decoder_layer " + std::to_string(layer_id_); model_name_ = "Qwen3_Moe"; diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h index 6b74ac4ea..e73bad5fc 100644 --- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h @@ -26,6 +26,7 @@ limitations under the License. #include "framework/model/npu_dp_ep_padding.h" #include "framework/quant_args.h" #include "framework/state_dict/state_dict.h" +#include "loader/qwen3_moe_decoder_loader.h" #include "npu_base_layer.h" #include "xllm_kernels/core/include/atb_speed/base/hosttensor_binder.h" #include "xllm_kernels/core/include/atb_speed/base/model.h" @@ -43,10 +44,6 @@ class Qwen3MoeDecoderLayerImpl : public BaseLayer { ~Qwen3MoeDecoderLayerImpl() {}; - virtual void load_state_dict(const StateDict& state_dict); - - virtual void verify_loaded_weights(const std::string& prefix) const; - virtual void merge_loaded_weights(); virtual int64_t init_layer() override; @@ -71,15 +68,11 @@ class Qwen3MoeDecoderLayerImpl : public BaseLayer { void initialize_tensors(const torch::TensorOptions& options); - void initialize_weight_tensors(const torch::TensorOptions& options); - void param_from_args(atb_speed::qwen::MoeDecoderLayerParam& param, const ModelArgs& args, const ParallelArgs& parallel_args, bool is_prefill); - void resize_experts_weights(int num_of_device_experts); - void initialize_basic_parameters(atb_speed::qwen::MoeDecoderLayerParam& param, const ModelArgs& args, const ParallelArgs& parallel_args, @@ -101,68 +94,6 @@ class Qwen3MoeDecoderLayerImpl : public BaseLayer { void initialize_quantization_parameters( atb_speed::qwen::MoeDecoderLayerParam& param); - torch::Tensor get_sharded_tensor(const StateDict& state_dict, - const std::string& name, - int dim); - torch::Tensor get_sharded_tensor(const StateDict& state_dict, - const std::string& name, - int dim, - int local_tp_rank, - int local_tp_size); - - std::string extract_endswith(const std::string& input); - - void set_kv_weight(const StateDict& state_dict, - const std::string& tensor_name, - int weight_position, - int dim); - - int extract_expert_index(const std::string& name); - - void convert_descaled_weights_to_float(); - - torch::Tensor convert_fp16_to_int64(const torch::Tensor& fp16_tensor); - - void merge_shared_experts_weights(); - - void merge_experts_weights(); - - void squeeze_experts_weights(); - - void preprocess_linear_for_rope(); - - void process_expert_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor); - - void process_shared_expert_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor); - - void process_mlp_common_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor); - - void process_general_weights(const StateDict& state_dict, - const std::string& name, - const torch::Tensor& tensor); - - int get_mapped_index(const std::string& name, - const std::unordered_map& mapping); - - torch::Tensor view_tensor(torch::Tensor weight, - const std::string& name, - bool pre_view); - - torch::Tensor trans_rope_weight(torch::Tensor weight); - - torch::Tensor merge_experts_weights(std::vector& experts, - bool transpose = false); - - torch::Tensor merge_experts_weights(std::vector& experts_up, - std::vector& experts_gate, - bool transpose = false); - int64_t 
init_node(atb_speed::Model::Node& node, atb_speed::qwen::MoeDecoderLayerParam& param); @@ -190,7 +121,6 @@ class Qwen3MoeDecoderLayerImpl : public BaseLayer { int32_t start_expert_id_; int32_t end_expert_id_; int32_t ep_rank_; - int32_t n_kv_heads_; int32_t dp_size_; int32_t dp_local_tp_size_; @@ -214,8 +144,6 @@ class Qwen3MoeDecoderLayerImpl : public BaseLayer { torch::Tensor one_hot_; torch::Tensor zero_hot_; torch::Tensor final_hidden_states_; - torch::Tensor at_start_expert_id_; - torch::Tensor at_in_device_expert_count_; std::vector int_placeholder_; diff --git a/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.cpp index 074a36528..4546d0f37 100644 --- a/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.cpp @@ -28,51 +28,8 @@ limitations under the License. namespace xllm { namespace layer { -enum VisionEncoderLayerTensorId : int { - IN_INPUT_NORM_WEIGHT = 0, - IN_INPUT_NORM_BIAS, - IN_POST_NORM_WEIGHT, - IN_POST_NORM_BIAS, - IN_QKV_WEIGHT, - IN_QKV_BIAS, - IN_WATTENTION_OUT_WEIGHT, - IN_WATTENTION_OUT_BIAS, - IN_LINEAR_FC1_WEIGHT, - IN_LINEAR_FC1_BIAS, - IN_LINEAR_FC2_WEIGHT, - IN_LINEAR_FC2_BIAS, - IN_VISION_Q_WEIGHT, - IN_VISION_Q_BIAS, - IN_VISION_K_WEIGHT, - IN_VISION_K_BIAS, - IN_VISION_V_WEIGHT, - IN_VISION_V_BIAS -}; - const uint64_t WEIGHT_COUNT_PER_LAYER = 18; -static std::vector> WEIGHT_MAPPING = { - {IN_INPUT_NORM_WEIGHT, "norm1.weight"}, - {IN_INPUT_NORM_BIAS, "norm1.bias"}, - {IN_POST_NORM_WEIGHT, "norm2.weight"}, - {IN_POST_NORM_BIAS, "norm2.bias"}, - {IN_QKV_WEIGHT, "attn.qkv.weight"}, - {IN_QKV_BIAS, "attn.qkv.bias"}, - {IN_WATTENTION_OUT_WEIGHT, "attn.proj.weight"}, - {IN_WATTENTION_OUT_BIAS, "attn.proj.bias"}, - {IN_LINEAR_FC1_WEIGHT, "mlp.linear_fc1.weight"}, - {IN_LINEAR_FC1_BIAS, "mlp.linear_fc1.bias"}, - {IN_LINEAR_FC2_WEIGHT, "mlp.linear_fc2.weight"}, - {IN_LINEAR_FC2_BIAS, "mlp.linear_fc2.bias"}}; - -// {weight,dim} -static std::map WEIGHT_SHARD = { - {IN_WATTENTION_OUT_WEIGHT, 1}, - {IN_LINEAR_FC1_WEIGHT, 0}, - {IN_LINEAR_FC1_BIAS, 0}, - {IN_LINEAR_FC2_WEIGHT, 1}, -}; - void Qwen3VisionEncoderLayerImpl::param_from_args( atb_speed::qwen::VisionEncoderLayerParam& param, const ModelArgs& args, @@ -99,89 +56,27 @@ Qwen3VisionEncoderLayerImpl::Qwen3VisionEncoderLayerImpl( auto parallel_args = context.get_parallel_args(); auto options = context.get_tensor_options(); param_from_args(encode_param_, model_args, parallel_args); - at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); + // at_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); atb_weight_tensors_.resize(WEIGHT_COUNT_PER_LAYER); dtype_ = c10::typeMetaToScalarType(options.dtype()); device_id_ = options.device().index(); placeholder_ = atb_speed::Utils::AtTensor2Tensor( torch::zeros({1}).to(device_).to(dtype_)); - at_placeholder_ = torch::zeros({1}).to(device_).to(dtype_); - for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { - at_weight_tensors_[i] = torch::zeros({1}).to(options); - } -} - -void Qwen3VisionEncoderLayerImpl::verify_loaded_weights() const { - for (const auto& [index, name] : WEIGHT_MAPPING) { - CHECK(at_weight_tensors_[index].sizes() != std::vector({1})) - << "weight is not loaded for " << name; - } + loader_ = std::make_unique(WEIGHT_COUNT_PER_LAYER, + context); } void Qwen3VisionEncoderLayerImpl::merge_loaded_weights() { - // spilt pack qkv weight when enable tp - get_weights_col_packed_qkv(); - if (encode_param_.worldSize > 1) { - // merge qkv weight - 
auto new_qkv_weight = torch::cat({at_weight_tensors_[IN_VISION_Q_WEIGHT], - at_weight_tensors_[IN_VISION_K_WEIGHT], - at_weight_tensors_[IN_VISION_V_WEIGHT]}, - 0); - at_weight_tensors_[IN_QKV_WEIGHT] = new_qkv_weight; - at_weight_tensors_[IN_VISION_Q_WEIGHT] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_VISION_K_WEIGHT] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_VISION_V_WEIGHT] = torch::zeros({1}).to(device_); - - // merge qkv bias - auto new_qkv_bias = torch::cat({at_weight_tensors_[IN_VISION_Q_BIAS], - at_weight_tensors_[IN_VISION_K_BIAS], - at_weight_tensors_[IN_VISION_V_BIAS]}, - 0); - at_weight_tensors_[IN_QKV_BIAS] = new_qkv_bias; - at_weight_tensors_[IN_VISION_Q_BIAS] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_VISION_K_BIAS] = torch::zeros({1}).to(device_); - at_weight_tensors_[IN_VISION_V_BIAS] = torch::zeros({1}).to(device_); - } + loader_->merge_loaded_weights(); + auto& at_weight_tensors = loader_->get_at_weight_tensors(); c10_npu::NPUCachingAllocator::emptyCache(); for (int i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { atb_weight_tensors_[i] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[i]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[i]); } init_layer(); } -// tp spilt weight -void Qwen3VisionEncoderLayerImpl::get_weights_col_packed_qkv() { - int rank = encode_param_.rank; - int worldSize = encode_param_.worldSize; - // split qkv weight - qkv_weight = torch::chunk(at_weight_tensors_[IN_QKV_WEIGHT], 3, 0); - qkv_bias = torch::chunk(at_weight_tensors_[IN_QKV_BIAS], 3, 0); - // weight - at_weight_tensors_[IN_VISION_Q_WEIGHT] = - (qkv_weight[0].chunk(worldSize, 0))[rank]; - at_weight_tensors_[IN_VISION_K_WEIGHT] = - (qkv_weight[1].chunk(worldSize, 0))[rank]; - at_weight_tensors_[IN_VISION_V_WEIGHT] = - (qkv_weight[2].chunk(worldSize, 0))[rank]; - // bias - at_weight_tensors_[IN_VISION_Q_BIAS] = - (qkv_bias[0].chunk(worldSize, 0))[rank]; - at_weight_tensors_[IN_VISION_K_BIAS] = - (qkv_bias[1].chunk(worldSize, 0))[rank]; - at_weight_tensors_[IN_VISION_V_BIAS] = - (qkv_bias[2].chunk(worldSize, 0))[rank]; -} - -void Qwen3VisionEncoderLayerImpl::load_state_dict(const StateDict& state_dict) { - for (const auto& [index, name] : WEIGHT_MAPPING) { - if (WEIGHT_SHARD.find(index) != WEIGHT_SHARD.end()) { - set_weight(state_dict, name, index, WEIGHT_SHARD[index]); - } else { - set_weight(state_dict, name, index); - } - } -} int64_t Qwen3VisionEncoderLayerImpl::init_layer() { name_ = "qwen3_encoder_layer"; @@ -272,9 +167,6 @@ void Qwen3VisionEncoderLayerImpl::build_node_variant_pack( CHECK_THROW(node.inTensors.at(i) == nullptr, model_name_ << "inTensor " << i << "is NULL"); node.variantPack.inTensors.at(i) = *node.inTensors.at(i); - // LOG(INFO) << model_name_ << "inTensors[" << i << "]:" - // << atb_speed::TensorUtil::TensorToString( - // node.variantPack.inTensors.at(i)); } node.variantPack.outTensors.at(0) = internal_tensors_; diff --git a/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h index 5b501f5d4..9ecba219d 100755 --- a/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen3_vision_encoder_layer_impl.h @@ -35,6 +35,7 @@ limitations under the License. 
#include "core/framework/model/model_args.h" #include "core/framework/model/model_input_params.h" #include "core/framework/state_dict/state_dict.h" +#include "loader/qwen3_vision_encoder_loader.h" #include "nlohmann/json.hpp" #include "npu_base_layer.h" #include "pytorch/adapter/utils/utils.h" @@ -49,9 +50,9 @@ class Qwen3VisionEncoderLayerImpl : public BaseLayer { ~Qwen3VisionEncoderLayerImpl() {}; - void load_state_dict(const StateDict& state_dict) override; + // void load_state_dict(const StateDict& state_dict) override; - void verify_loaded_weights() const override; + // void verify_loaded_weights() const override; void merge_loaded_weights() override; diff --git a/xllm/core/layers/npu/npu_rms_norm_impl.cpp b/xllm/core/layers/npu/npu_rms_norm_impl.cpp index 3dba6c6a9..f8b234580 100644 --- a/xllm/core/layers/npu/npu_rms_norm_impl.cpp +++ b/xllm/core/layers/npu/npu_rms_norm_impl.cpp @@ -29,30 +29,19 @@ void RMSNormImpl::param_from_args(atb::infer::RmsNormParam& param, RMSNormImpl::RMSNormImpl(const ModelContext& context) : BaseLayer(context) { param_from_args(norm_param_, context.get_model_args()); - at_weight_tensors_.resize(1); + // at_weight_tensors_.resize(1); atb_weight_tensors_.resize(1); - - auto options = context.get_tensor_options(); - dtype_ = c10::typeMetaToScalarType(options.dtype()); - at_weight_tensors_[0] = torch::zeros({1}).to(options); -} - -void RMSNormImpl::verify_loaded_weights(const std::string weight_str) const { - CHECK(at_weight_tensors_[0].sizes() != std::vector({1})) - << "final norm weight is not loaded for " << weight_str; + loader_ = std::make_unique((uint64_t)1, // 1 weight + context); } void RMSNormImpl::merge_loaded_weights() { + auto& at_weight_tensors = loader_->get_at_weight_tensors(); atb_weight_tensors_[0] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[0]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[0]); init_layer(); } -void RMSNormImpl::load_state_dict(const StateDict& state_dict) { - set_weight(state_dict, "weight", 0); - at_weight_tensors_[0] = at_weight_tensors_[0].to(dtype_); -} - int64_t RMSNormImpl::init_layer() { name_ = "rms_norm_layer"; model_name_ = "llm"; diff --git a/xllm/core/layers/npu/npu_rms_norm_impl.h b/xllm/core/layers/npu/npu_rms_norm_impl.h index a39644cac..5aa33b5e1 100644 --- a/xllm/core/layers/npu/npu_rms_norm_impl.h +++ b/xllm/core/layers/npu/npu_rms_norm_impl.h @@ -31,6 +31,7 @@ limitations under the License. 
#include "framework/model/model_input_params.h" #include "framework/model_context.h" #include "framework/state_dict/state_dict.h" +#include "loader/rms_norm_loader.h" #include "nlohmann/json.hpp" #include "npu_base_layer.h" #include "pytorch/adapter/utils/utils.h" @@ -48,9 +49,9 @@ class RMSNormImpl : public BaseLayer { ~RMSNormImpl() {}; - void load_state_dict(const StateDict& state_dict) override; + // void load_state_dict(const StateDict& state_dict) override; - void verify_loaded_weights(const std::string weight_str) const; + // void verify_loaded_weights(const std::string weight_str) const; void merge_loaded_weights() override; diff --git a/xllm/core/layers/npu/npu_siglip_encoder_layer_impl.cpp b/xllm/core/layers/npu/npu_siglip_encoder_layer_impl.cpp index 95bc05c50..a81c4a47d 100644 --- a/xllm/core/layers/npu/npu_siglip_encoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_siglip_encoder_layer_impl.cpp @@ -27,6 +27,7 @@ SiglipEncoderLayerUpImpl::SiglipEncoderLayerUpImpl(const ModelContext& context, model_args_(context.get_model_args()), options_(context.get_tensor_options()), prefix_(prefix) { + loader_ = std::make_unique(context); build_graph(prefix); } @@ -142,24 +143,8 @@ void SiglipEncoderLayerUpImpl::build_graph(const std::string& prefix) { } void SiglipEncoderLayerUpImpl::load_state_dict(const StateDict& state_dict) { - const std::set key_names = {"layer_norm1.weight", - "layer_norm1.bias", - "self_attn.q_proj.weight", - "self_attn.q_proj.bias", - "self_attn.k_proj.weight", - "self_attn.k_proj.bias", - "self_attn.v_proj.weight", - "self_attn.v_proj.bias"}; - - atb_torch::TorchTensorMap weights_map; - for (const auto& [name, tensor] : state_dict) { - if (key_names.find(name) == key_names.end()) continue; - - auto weight_npu = tensor.to(options_); - - weights_.push_back(weight_npu); - weights_map[name] = weight_npu; - } + loader_->load_state_dict(state_dict); + auto weights_map = loader_->get_weights_map(); graph_.SetWeights(weights_map); } @@ -189,6 +174,8 @@ NpuSiglipEncoderLayerDownImpl::NpuSiglipEncoderLayerDownImpl( model_args_(context.get_model_args()), options_(context.get_tensor_options()), prefix_(prefix) { + loader_ = std::make_unique(context); + build_graph(prefix); } @@ -297,24 +284,8 @@ void NpuSiglipEncoderLayerDownImpl::build_graph(const std::string& prefix) { void NpuSiglipEncoderLayerDownImpl::load_state_dict( const StateDict& state_dict) { - const std::set key_names = {"self_attn.out_proj.weight", - "self_attn.out_proj.bias", - "layer_norm2.weight", - "layer_norm2.bias", - "mlp.fc1.weight", - "mlp.fc1.bias", - "mlp.fc2.weight", - "mlp.fc2.bias"}; - - atb_torch::TorchTensorMap weights_map; - for (const auto& [name, tensor] : state_dict) { - if (key_names.find(name) == key_names.end()) continue; - - auto weight_npu = tensor.to(options_); - - weights_.push_back(weight_npu); - weights_map[name] = weight_npu; - } + loader_->load_state_dict(state_dict); + auto& weights_map = loader_->get_weights_map(); graph_.SetWeights(weights_map); } diff --git a/xllm/core/layers/npu/npu_siglip_encoder_layer_impl.h b/xllm/core/layers/npu/npu_siglip_encoder_layer_impl.h index edc01fccf..a9ba679aa 100644 --- a/xllm/core/layers/npu/npu_siglip_encoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_siglip_encoder_layer_impl.h @@ -19,6 +19,7 @@ limitations under the License. 
#include "framework/model/model_input_params.h" #include "framework/model_context.h" #include "framework/state_dict/state_dict.h" +#include "loader/siglip_encoder_loader.h" #include "npu_base_layer.h" #include "xllm_kernels/pytorch/atb_torch/core/include/base_operation.h" #include "xllm_kernels/pytorch/atb_torch/core/include/graph_operation.h" diff --git a/xllm/core/layers/npu/npu_word_embedding_impl.cpp b/xllm/core/layers/npu/npu_word_embedding_impl.cpp index edc3a0b85..c0f8652a1 100644 --- a/xllm/core/layers/npu/npu_word_embedding_impl.cpp +++ b/xllm/core/layers/npu/npu_word_embedding_impl.cpp @@ -51,34 +51,19 @@ WordEmbeddingImpl::WordEmbeddingImpl(const ModelContext& context) auto options = context.get_tensor_options(); param_from_args(embedding_param_, model_args, parallel_args); - at_weight_tensors_.resize(1); atb_weight_tensors_.resize(1); atOutTensors_.resize(1); dtype_ = c10::typeMetaToScalarType(options.dtype()); - at_weight_tensors_[0] = torch::zeros({1}).to(options); -} - -void WordEmbeddingImpl::verify_loaded_weights( - const std::string weight_str) const { - CHECK(at_weight_tensors_[0].sizes() != std::vector({1})) - << "weight is not loaded for " << weight_str; + loader_ = std::make_unique(1, context); } void WordEmbeddingImpl::merge_loaded_weights() { + auto& at_weight_tensors = loader_->get_at_weight_tensors(); atb_weight_tensors_[0] = - atb_speed::Utils::AtTensor2Tensor(at_weight_tensors_[0]); + atb_speed::Utils::AtTensor2Tensor(at_weight_tensors[0]); init_layer(); } -void WordEmbeddingImpl::load_state_dict(const StateDict& state_dict) { - if (dp_size_ > 1) { - set_weight( - state_dict, "weight", 0, 1, dp_local_tp_rank_, dp_local_tp_size_); - } else { - set_weight(state_dict, "weight", 0, 1); - } -} - int64_t WordEmbeddingImpl::init_layer() { BaseLayer::name_ = "word_embedding_layer"; modelName_ = "llm"; diff --git a/xllm/core/layers/npu/npu_word_embedding_impl.h b/xllm/core/layers/npu/npu_word_embedding_impl.h index b11ff394f..af79f42ab 100644 --- a/xllm/core/layers/npu/npu_word_embedding_impl.h +++ b/xllm/core/layers/npu/npu_word_embedding_impl.h @@ -28,6 +28,7 @@ limitations under the License. #include "atb/atb_infer.h" #include "framework/model/model_input_params.h" +#include "loader/word_embedding_loader.h" #include "nlohmann/json.hpp" #include "npu_base_layer.h" #include "pytorch/adapter/utils/utils.h" @@ -46,10 +47,6 @@ class WordEmbeddingImpl : public BaseLayer { ~WordEmbeddingImpl() {}; - void load_state_dict(const StateDict& state_dict) override; - - void verify_loaded_weights(const std::string weight_str) const; - void merge_loaded_weights() override; void param_from_args(atb_speed::common::WordEmbeddingParam& param,