Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions xllm/core/layers/npu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ cc_library(
npu_siglip_encoder_layer_impl.h
../common/rotary_embedding_util.h
rotary_embedding.h
loader/qwen3_decoder_loader.h
loader/qwen2_decoder_loader.h
loader/qwen3_moe_decoder_loader.h
loader/word_embedding_loader.h
loader/lm_head_loader.h
loader/column_parallel_linear_loader.h
loader/deepseek_v2_decoder_loader.h
loader/glm4_moe_decoder_loader.h
loader/llama_decoder_loader.h
loader/qwen2dot5_vision_encoder_loader.h
loader/qwen3_vision_encoder_loader.h
loader/rms_norm_loader.h
loader/siglip_encoder_loader.h
loader/base_loader.h
SRCS
npu_word_embedding_impl.cpp
npu_pos_embedding_impl.cpp
Expand All @@ -53,6 +67,20 @@ cc_library(
npu_siglip_encoder_layer_impl.cpp
../common/rotary_embedding_util.cpp
rotary_embedding.cpp
loader/qwen3_decoder_loader.cpp
loader/qwen2_decoder_loader.cpp
loader/qwen3_moe_decoder_loader.cpp
loader/word_embedding_loader.cpp
loader/lm_head_loader.cpp
loader/column_parallel_linear_loader.cpp
loader/deepseek_v2_decoder_loader.cpp
loader/glm4_moe_decoder_loader.cpp
loader/llama_decoder_loader.cpp
loader/qwen2dot5_vision_encoder_loader.cpp
loader/qwen3_vision_encoder_loader.cpp
loader/rms_norm_loader.cpp
loader/siglip_encoder_loader.cpp
loader/base_loader.cpp
DEPS
"-Wl,--whole-archive"
"-Wl,--no-whole-archive"
Expand Down
146 changes: 146 additions & 0 deletions xllm/core/layers/npu/loader/base_loader.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "base_loader.h"

namespace xllm {
namespace layer {

// Builds the shared loader state from the model context.
//
// weight_count: number of slots allocated in at_weight_tensors_.
// context:      source of the parallel args, target device and quant args.
//
// The data-parallel layout is derived from the parallel args:
//   dp_local_tp_size_ = world_size / dp_size
//   dp_rank_          = rank / dp_local_tp_size_
//   dp_local_tp_rank_ = rank % dp_local_tp_size_
// Aborts (CHECK) if dp_size is not positive or does not evenly divide
// world_size.
BaseLoader::BaseLoader(uint64_t weight_count, const ModelContext& context)
    : weight_count_(weight_count),
      parallel_args_(context.get_parallel_args()),
      device_(context.get_tensor_options().device()) {
  const auto& quant_args = context.get_quant_args();
  // Only override the member defaults when the quant args carry a value.
  if (!quant_args.quantize_type().empty()) {
    quantize_type_ = quant_args.quantize_type();
  }

  if (!quant_args.torch_dtype().empty()) {
    torch_dtype_ = quant_args.torch_dtype();
  }

  // Validate before dividing: a zero divisor would otherwise raise an
  // integer-division fault instead of a readable CHECK failure.
  dp_size_ = parallel_args_.dp_size();
  CHECK_GT(dp_size_, 0) << "dp_size must be positive";

  dp_local_tp_size_ = parallel_args_.world_size() / dp_size_;
  CHECK_GT(dp_local_tp_size_, 0)
      << "world_size must be at least dp_size (world_size="
      << parallel_args_.world_size() << ", dp_size=" << dp_size_ << ")";
  CHECK_EQ(parallel_args_.world_size(), dp_size_ * dp_local_tp_size_);

  dp_rank_ = parallel_args_.rank() / dp_local_tp_size_;
  dp_local_tp_rank_ = parallel_args_.rank() % dp_local_tp_size_;

  at_weight_tensors_.resize(weight_count_);
}

void BaseLoader::set_weight(const StateDict& state_dict,
const std::string& tensor_name,
int weight_position) {
for (const auto& [name, tensor] : state_dict) {
if (absl::EndsWith(name, tensor_name)) {
at::Tensor mutable_tensor = tensor;
correct_tensor_dtype(mutable_tensor, tensor_name);
at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
}
}
}

// Stores a weight into slot `weight_position`, sharded along `dim` across the
// rank / world size reported by parallel_args_.
//
// With a world size of at most 1 the full tensor is loaded; otherwise the
// shard for this rank is fetched from the state dict. The work is done by the
// explicit rank/world_size overload.
void BaseLoader::set_weight(const StateDict& state_dict,
                            const std::string& tensor_name,
                            int weight_position,
                            int dim) {
  set_weight(state_dict,
             tensor_name,
             weight_position,
             dim,
             /*rank=*/parallel_args_.rank(),
             /*world_size=*/parallel_args_.world_size());
}

void BaseLoader::set_weight(const StateDict& state_dict,
const std::string& tensor_name,
int weight_position,
int dim,
int rank,
int world_size) {
for (const auto& [name, tensor] : state_dict) {
if (absl::EndsWith(name, tensor_name)) {
if (world_size <= 1) {
at::Tensor mutable_tensor = tensor;
correct_tensor_dtype(mutable_tensor, tensor_name);
at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
} else {
at_weight_tensors_[weight_position] =
state_dict
.get_sharded_tensor(tensor_name,
/*dim=*/dim,
/*rank=*/rank,
/*world_size=*/world_size)
.to(device_);
correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name);
}
}
}
}

// Casts `tensor` in place to the dtype named by torch_dtype_.
//
// Left untouched:
//   - tensors whose name ends with "deq_scale" when torch_dtype_ is
//     "bfloat16";
//   - tensors whose dtype is already int8, int32 or int64.
void BaseLoader::correct_tensor_dtype(torch::Tensor& tensor,
                                      const std::string& tensorName) {
  const bool keep_deq_scale =
      absl::EndsWith(tensorName, "deq_scale") && torch_dtype_ == "bfloat16";
  if (keep_deq_scale) {
    return;
  }

  const auto current = tensor.dtype();
  const bool needs_cast = current != torch::kInt8 &&
                          current != torch::kInt32 &&
                          current != torch::kInt64;
  if (!needs_cast) {
    return;
  }

  tensor = tensor.to(string2dtype(torch_dtype_));
}

// Maps a dtype name (e.g. "float16", "bfloat16", "int8") to the matching
// torch::Dtype. Logs a FATAL error for any name not in the table.
torch::Dtype BaseLoader::string2dtype(const std::string& dtype_str) {
  struct NamedDtype {
    const char* name;
    torch::Dtype dtype;
  };
  static const NamedDtype kSupportedDtypes[] = {
      {"float16", torch::kFloat16},
      {"bfloat16", torch::kBFloat16},
      {"float32", torch::kFloat32},
      {"float64", torch::kFloat64},
      {"int8", torch::kInt8},
      {"int16", torch::kInt16},
      {"int32", torch::kInt32},
      {"int64", torch::kInt64},
      {"uint8", torch::kUInt8},
      {"bool", torch::kBool},
  };

  for (const auto& candidate : kSupportedDtypes) {
    if (dtype_str == candidate.name) {
      return candidate.dtype;
    }
  }

  LOG(FATAL) << "Unsupported dtype string: " << dtype_str;
}

} // namespace layer
} // namespace xllm
102 changes: 102 additions & 0 deletions xllm/core/layers/npu/loader/base_loader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <absl/strings/match.h>
#include <torch/torch.h>

#include "framework/eplb/expert_buffer_manager.h"
#include "framework/kv_cache/kv_cache.h"
#include "framework/model/model_input_params.h"
#include "framework/model_context.h"
#include "framework/state_dict/state_dict.h"
#include "xllm_kernels/pytorch/atb_torch/core/include/base_operation.h"
#include "xllm_kernels/pytorch/atb_torch/core/include/graph_operation.h"

namespace xllm {
namespace layer {

class BaseLoader {
public:
BaseLoader(uint64_t weight_count, const ModelContext& context);
virtual ~BaseLoader() = default;

virtual void load_state_dict(const StateDict& state_dict) {};
virtual void verify_loaded_weights() const {};
virtual void verify_loaded_weights(const std::string& prefix) const {};
virtual void merge_loaded_weights() {};
virtual void resize_experts_weights(int num_of_device_experts) {};
torch::Dtype string2dtype(const std::string& dtype_str);

void correct_tensor_dtype(torch::Tensor& tensor,
const std::string& tensorName);

std::vector<at::Tensor>& get_at_weight_tensors() {
return at_weight_tensors_;
}

std::unordered_map<std::string, std::vector<torch::Tensor>>&
get_experts_weight_tensors() {
return experts_weights_;
}

std::unique_ptr<ExpertBufferManager>& get_expert_shared_buffer() {
return shared_buffer_;
}

std::vector<int32_t>& get_device_expert_list() { return device_expert_list_; }

atb_torch::TorchTensorMap& get_weights_map() { return weights_map_; }

protected:
uint64_t weight_count_;
xllm::ParallelArgs parallel_args_;
std::string quantize_type_;
std::string torch_dtype_;
torch::ScalarType dtype_;
torch::TensorOptions options_;
std::vector<at::Tensor> at_weight_tensors_;
std::unique_ptr<ExpertBufferManager> shared_buffer_ = nullptr;
std::unordered_map<std::string, torch::Tensor> shared_experts_weights_;
std::unordered_map<std::string, std::vector<torch::Tensor>> experts_weights_;
std::vector<int32_t> device_expert_list_;
atb_torch::TorchTensorMap weights_map_;

at::Device device_;
int32_t dp_size_;
int32_t dp_local_tp_size_;
int32_t dp_rank_;
int32_t dp_local_tp_rank_;

void set_weight(const StateDict& state_dict,
const std::string& tensor_name,
int weight_position);

void set_weight(const StateDict& state_dict,
const std::string& tensor_name,
int weight_position,
int dim);

void set_weight(const StateDict& state_dict,
const std::string& tensor_name,
int weight_position,
int dim,
int rank,
int world_size);
};

} // namespace layer
} // namespace xllm
47 changes: 47 additions & 0 deletions xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "column_parallel_linear_loader.h"

namespace xllm {
namespace layer {

// Creates the loader and seeds slot 0 with a placeholder tensor of shape {1}.
//
// The placeholder lets verify_loaded_weights() detect that no real weight was
// loaded. dtype_ is taken from the context's tensor options.
// Aborts (CHECK) if `weight_count` is 0, since slot 0 would not exist.
ColumParallelLinearLoader::ColumParallelLinearLoader(
    uint64_t weight_count,
    const ModelContext& context)
    : BaseLoader(weight_count, context) {
  CHECK(!at_weight_tensors_.empty())
      << "ColumParallelLinearLoader requires weight_count >= 1, got "
      << weight_count;

  const auto options = context.get_tensor_options();
  dtype_ = torch::typeMetaToScalarType(options.dtype());
  at_weight_tensors_[0] = torch::zeros({1}).to(options);
}

void ColumParallelLinearLoader::load_state_dict(const StateDict& state_dict) {
if (dp_size_ > 1) {
set_weight(
state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_);
} else {
set_weight(state_dict, "weight", 0, 0);
}
at_weight_tensors_[0] = at_weight_tensors_[0].to(dtype_);
}

// Aborts (CHECK) if slot 0 still has the placeholder shape {1} set by the
// constructor, i.e. no weight was loaded. `weight_str` only labels the
// failure message.
void ColumParallelLinearLoader::verify_loaded_weights(
    const std::string& weight_str) const {
  const std::vector<int64_t> placeholder_shape({1});
  const bool weight_loaded = at_weight_tensors_[0].sizes() != placeholder_shape;
  CHECK(weight_loaded) << "weight is not loaded for " << weight_str;
}

} // namespace layer
} // namespace xllm
28 changes: 28 additions & 0 deletions xllm/core/layers/npu/loader/column_parallel_linear_loader.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include "base_loader.h"

namespace xllm {
namespace layer {

// Loader for a column-parallel linear layer: loads the single "weight" tensor
// into slot 0 of the base loader's weight vector.
class ColumParallelLinearLoader : public BaseLoader {
 public:
  // `weight_count` is the number of weight slots to allocate; slot 0 is used
  // for the layer weight.
  ColumParallelLinearLoader(uint64_t weight_count, const ModelContext& context);

  // Loads "weight" from `state_dict` into slot 0.
  void load_state_dict(const StateDict& state_dict) override;

  // Aborts if the weight was never loaded; `weight_str` labels the message.
  void verify_loaded_weights(const std::string& weight_str) const override;
};

}  // namespace layer
}  // namespace xllm
Loading
Loading