-
Notifications
You must be signed in to change notification settings - Fork 93
refactor: separate the weight loading in the npu layer class. #489
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| /* Copyright 2025 The xLLM Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| https://github.com/jd-opensource/xllm/blob/main/LICENSE | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| ==============================================================================*/ | ||
|
|
||
| #include "base_loader.h" | ||
|
|
||
| namespace xllm { | ||
| namespace layer { | ||
|
|
||
| BaseLoader::BaseLoader(uint64_t weight_count, const ModelContext& context) | ||
| : weight_count_(weight_count), | ||
| parallel_args_(context.get_parallel_args()), | ||
| device_(context.get_tensor_options().device()) { | ||
| auto quant_args = context.get_quant_args(); | ||
| if (!quant_args.quantize_type().empty()) { | ||
| quantize_type_ = quant_args.quantize_type(); | ||
| } | ||
|
|
||
| if (!quant_args.torch_dtype().empty()) { | ||
| torch_dtype_ = quant_args.torch_dtype(); | ||
| } | ||
|
|
||
| dp_size_ = parallel_args_.dp_size(); | ||
| dp_local_tp_size_ = parallel_args_.world_size() / dp_size_; | ||
| dp_rank_ = parallel_args_.rank() / dp_local_tp_size_; | ||
| CHECK_EQ(parallel_args_.world_size(), dp_size_ * dp_local_tp_size_); | ||
| dp_local_tp_rank_ = parallel_args_.rank() % dp_local_tp_size_; | ||
|
|
||
| at_weight_tensors_.resize(weight_count_); | ||
| } | ||
|
|
||
| void BaseLoader::set_weight(const StateDict& state_dict, | ||
| const std::string& tensor_name, | ||
| int weight_position) { | ||
| for (const auto& [name, tensor] : state_dict) { | ||
| if (absl::EndsWith(name, tensor_name)) { | ||
| at::Tensor mutable_tensor = tensor; | ||
| correct_tensor_dtype(mutable_tensor, tensor_name); | ||
| at_weight_tensors_[weight_position] = mutable_tensor.to(device_); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void BaseLoader::set_weight(const StateDict& state_dict, | ||
| const std::string& tensor_name, | ||
| int weight_position, | ||
| int dim) { | ||
| for (const auto& [name, tensor] : state_dict) { | ||
| if (absl::EndsWith(name, tensor_name)) { | ||
| if (parallel_args_.world_size() <= 1) { | ||
| at::Tensor mutable_tensor = tensor; | ||
| correct_tensor_dtype(mutable_tensor, tensor_name); | ||
| at_weight_tensors_[weight_position] = mutable_tensor.to(device_); | ||
| } else { | ||
| at_weight_tensors_[weight_position] = | ||
| state_dict | ||
| .get_sharded_tensor(tensor_name, | ||
| /*dim=*/dim, | ||
| /*rank=*/parallel_args_.rank(), | ||
| /*world_size=*/parallel_args_.world_size()) | ||
| .to(device_); | ||
| correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void BaseLoader::set_weight(const StateDict& state_dict, | ||
| const std::string& tensor_name, | ||
| int weight_position, | ||
Clement-Wang26 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| int dim, | ||
| int rank, | ||
| int world_size) { | ||
| for (const auto& [name, tensor] : state_dict) { | ||
| if (absl::EndsWith(name, tensor_name)) { | ||
| if (world_size <= 1) { | ||
| at::Tensor mutable_tensor = tensor; | ||
| correct_tensor_dtype(mutable_tensor, tensor_name); | ||
| at_weight_tensors_[weight_position] = mutable_tensor.to(device_); | ||
| } else { | ||
| at_weight_tensors_[weight_position] = | ||
| state_dict | ||
| .get_sharded_tensor(tensor_name, | ||
| /*dim=*/dim, | ||
| /*rank=*/rank, | ||
| /*world_size=*/world_size) | ||
| .to(device_); | ||
| correct_tensor_dtype(at_weight_tensors_[weight_position], tensor_name); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void BaseLoader::correct_tensor_dtype(torch::Tensor& tensor, | ||
| const std::string& tensorName) { | ||
| if (absl::EndsWith(tensorName, "deq_scale") && | ||
| (torch_dtype_.compare("bfloat16") == 0)) { | ||
| return; | ||
| } | ||
|
|
||
| if (tensor.dtype() != torch::kInt8 && tensor.dtype() != torch::kInt32 && | ||
| tensor.dtype() != torch::kInt64) { | ||
| torch::Dtype dtype = string2dtype(torch_dtype_); | ||
Clement-Wang26 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| tensor = tensor.to(dtype); | ||
| } | ||
| } | ||
|
|
||
| torch::Dtype BaseLoader::string2dtype(const std::string& dtype_str) { | ||
| if (dtype_str.compare("float16") == 0) { | ||
Clement-Wang26 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return torch::kFloat16; | ||
| } else if (dtype_str.compare("bfloat16") == 0) { | ||
| return torch::kBFloat16; | ||
| } else if (dtype_str.compare("float32") == 0) { | ||
| return torch::kFloat32; | ||
| } else if (dtype_str.compare("float64") == 0) { | ||
| return torch::kFloat64; | ||
| } else if (dtype_str.compare("int8") == 0) { | ||
| return torch::kInt8; | ||
| } else if (dtype_str.compare("int16") == 0) { | ||
| return torch::kInt16; | ||
| } else if (dtype_str.compare("int32") == 0) { | ||
| return torch::kInt32; | ||
| } else if (dtype_str.compare("int64") == 0) { | ||
| return torch::kInt64; | ||
| } else if (dtype_str.compare("uint8") == 0) { | ||
| return torch::kUInt8; | ||
| } else if (dtype_str.compare("bool") == 0) { | ||
| return torch::kBool; | ||
| } | ||
|
|
||
| LOG(FATAL) << "Unsupported dtype string: " << dtype_str; | ||
| } | ||
|
|
||
| } // namespace layer | ||
| } // namespace xllm | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,102 @@ | ||
| /* Copyright 2025 The xLLM Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| https://github.com/jd-opensource/xllm/blob/main/LICENSE | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| ==============================================================================*/ | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <absl/strings/match.h> | ||
| #include <torch/torch.h> | ||
|
|
||
| #include "framework/eplb/expert_buffer_manager.h" | ||
| #include "framework/kv_cache/kv_cache.h" | ||
| #include "framework/model/model_input_params.h" | ||
| #include "framework/model_context.h" | ||
| #include "framework/state_dict/state_dict.h" | ||
| #include "xllm_kernels/pytorch/atb_torch/core/include/base_operation.h" | ||
| #include "xllm_kernels/pytorch/atb_torch/core/include/graph_operation.h" | ||
|
|
||
| namespace xllm { | ||
| namespace layer { | ||
|
|
||
| class BaseLoader { | ||
| public: | ||
| BaseLoader(uint64_t weight_count, const ModelContext& context); | ||
| virtual ~BaseLoader() = default; | ||
|
|
||
| virtual void load_state_dict(const StateDict& state_dict) {}; | ||
| virtual void verify_loaded_weights() const {}; | ||
| virtual void verify_loaded_weights(const std::string& prefix) const {}; | ||
| virtual void merge_loaded_weights() {}; | ||
| virtual void resize_experts_weights(int num_of_device_experts) {}; | ||
| torch::Dtype string2dtype(const std::string& dtype_str); | ||
|
|
||
| void correct_tensor_dtype(torch::Tensor& tensor, | ||
| const std::string& tensorName); | ||
|
|
||
| std::vector<at::Tensor>& get_at_weight_tensors() { | ||
| return at_weight_tensors_; | ||
| } | ||
|
|
||
| std::unordered_map<std::string, std::vector<torch::Tensor>>& | ||
| get_experts_weight_tensors() { | ||
| return experts_weights_; | ||
| } | ||
|
|
||
| std::unique_ptr<ExpertBufferManager>& get_expert_shared_buffer() { | ||
| return shared_buffer_; | ||
| } | ||
|
|
||
| std::vector<int32_t>& get_device_expert_list() { return device_expert_list_; } | ||
|
|
||
| atb_torch::TorchTensorMap& get_weights_map() { return weights_map_; } | ||
|
|
||
| protected: | ||
| uint64_t weight_count_; | ||
| xllm::ParallelArgs parallel_args_; | ||
| std::string quantize_type_; | ||
| std::string torch_dtype_; | ||
| torch::ScalarType dtype_; | ||
| torch::TensorOptions options_; | ||
| std::vector<at::Tensor> at_weight_tensors_; | ||
| std::unique_ptr<ExpertBufferManager> shared_buffer_ = nullptr; | ||
| std::unordered_map<std::string, torch::Tensor> shared_experts_weights_; | ||
| std::unordered_map<std::string, std::vector<torch::Tensor>> experts_weights_; | ||
| std::vector<int32_t> device_expert_list_; | ||
| atb_torch::TorchTensorMap weights_map_; | ||
|
|
||
| at::Device device_; | ||
| int32_t dp_size_; | ||
| int32_t dp_local_tp_size_; | ||
| int32_t dp_rank_; | ||
| int32_t dp_local_tp_rank_; | ||
|
|
||
| void set_weight(const StateDict& state_dict, | ||
| const std::string& tensor_name, | ||
| int weight_position); | ||
|
|
||
| void set_weight(const StateDict& state_dict, | ||
| const std::string& tensor_name, | ||
| int weight_position, | ||
| int dim); | ||
|
|
||
| void set_weight(const StateDict& state_dict, | ||
| const std::string& tensor_name, | ||
| int weight_position, | ||
| int dim, | ||
| int rank, | ||
| int world_size); | ||
| }; | ||
|
|
||
| } // namespace layer | ||
| } // namespace xllm |
47 changes: 47 additions & 0 deletions
47
xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| /* Copyright 2025 The xLLM Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| https://github.com/jd-opensource/xllm/blob/main/LICENSE | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| ==============================================================================*/ | ||
|
|
||
| #include "column_parallel_linear_loader.h" | ||
|
|
||
| namespace xllm { | ||
| namespace layer { | ||
|
|
||
| ColumParallelLinearLoader::ColumParallelLinearLoader( | ||
| uint64_t weight_count, | ||
| const ModelContext& context) | ||
| : BaseLoader(weight_count, context) { | ||
| auto options = context.get_tensor_options(); | ||
| dtype_ = torch::typeMetaToScalarType(options.dtype()); | ||
| at_weight_tensors_[0] = torch::zeros({1}).to(options); | ||
| } | ||
|
|
||
| void ColumParallelLinearLoader::load_state_dict(const StateDict& state_dict) { | ||
| if (dp_size_ > 1) { | ||
| set_weight( | ||
| state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_); | ||
| } else { | ||
| set_weight(state_dict, "weight", 0, 0); | ||
| } | ||
| at_weight_tensors_[0] = at_weight_tensors_[0].to(dtype_); | ||
| } | ||
|
|
||
| void ColumParallelLinearLoader::verify_loaded_weights( | ||
| const std::string& weight_str) const { | ||
| CHECK(at_weight_tensors_[0].sizes() != std::vector<int64_t>({1})) | ||
Clement-Wang26 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| << "weight is not loaded for " << weight_str; | ||
| } | ||
|
|
||
| } // namespace layer | ||
| } // namespace xllm | ||
28 changes: 28 additions & 0 deletions
28
xllm/core/layers/npu/loader/column_parallel_linear_loader.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| /* Copyright 2025 The xLLM Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| https://github.com/jd-opensource/xllm/blob/main/LICENSE | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| ==============================================================================*/ | ||
|
|
||
| #include "base_loader.h" | ||
|
|
||
| namespace xllm { | ||
| namespace layer { | ||
| class ColumParallelLinearLoader : public BaseLoader { | ||
| public: | ||
| ColumParallelLinearLoader(uint64_t weight_count, const ModelContext& context); | ||
|
|
||
| void load_state_dict(const StateDict& state_dict) override; | ||
| void verify_loaded_weights(const std::string& weight_str) const override; | ||
| }; | ||
| } // namespace layer | ||
| } // namespace xllm |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.