@@ -12,79 +12,48 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
#pragma once

-#include <gflags/gflags.h>
#include <torch/torch.h>

-#include <boost/algorithm/string.hpp>
#include <string>
#include <vector>

-#include "core/common/global_flags.h"
-#include "core/framework/kv_cache/kv_cache.h"
-#include "core/framework/model/model_input_params.h"
-#include "core/framework/model/npu_dp_ep_padding.h"
-#include "core/framework/model_context.h"
-#include "core/layers/common/attention_mask_impl.h"
#include "core/layers/deepseek_v2_decoder_layer.h"
-#include "core/layers/lm_head.h"
-#include "core/layers/npu/npu_rms_norm_impl.h"
-#include "core/layers/npu/rotary_embedding.h"
-#include "core/layers/pos_embedding.h"
-#include "core/layers/word_embedding.h"
-#include "models/model_registry.h"
+#include "llm_model_base.h"
+
// DeepSeek v2 compatible with huggingface weights
// ref to:
// https://github.com/vllm-project/vllm/blob/v0.6.6/vllm/model_executor/models/deepseek_v2.py

namespace xllm {

-using torch::indexing::None;
-using ISlice = torch::indexing::Slice;
-
class DeepseekV2DecoderLayerImpl : public torch::nn::Module {
 public:
-  DeepseekV2DecoderLayerImpl(const ModelContext& context, const int32_t i) {
+  DeepseekV2DecoderLayerImpl(const ModelContext& context,
+                             const int32_t layer_index) {
    // register submodules
-    decoder_layer_ = register_module("decoder_layer",
-                                     layer::DeepseekV2DecoderLayer(context, i));
+    decoder_layer_ = register_module(
+        "decoder_layer", layer::DeepseekV2DecoderLayer(context, layer_index));
  }

  torch::Tensor forward(torch::Tensor& x,
-                        torch::Tensor& cos_pos,
-                        torch::Tensor& sin_pos,
-                        torch::Tensor& attn_mask,
+                        torch::Tensor& positions,
+                        const layer::AttentionMetadata& attn_metadata,
                        KVCache& kv_cache,
-                        const ModelInputParams& input_params,
-                        aclrtEvent* event,
-                        std::atomic<bool>* event_flag) {
-    return decoder_layer_(x,
-                          cos_pos,
-                          sin_pos,
-                          attn_mask,
-                          kv_cache,
-                          input_params,
-                          event,
-                          event_flag);
+                        const ModelInputParams& input_params) {
+    return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params);
  }

  void load_state_dict(const StateDict& state_dict) {
    decoder_layer_->load_state_dict(state_dict);
  }

-  void verify_loaded_weights(const std::string& prefix) const {
-    decoder_layer_->verify_loaded_weights(prefix);
+  virtual void prepare_expert_weight(int32_t layer_id,
+                                     const std::vector<int32_t>& expert_ids) {
+    return;
  }
-
-  void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); }
-
-  void prepare_expert_weight(const std::vector<int32_t>& expert_list) {
-    decoder_layer_->prepare_expert_weight(expert_list);
-  }
-
-  void update_expert_weight() { decoder_layer_->update_expert_weight(); }
+  virtual void update_expert_weight(int32_t layer_id) { return; }

 private:
  layer::DeepseekV2DecoderLayer decoder_layer_{nullptr};
@@ -93,102 +62,64 @@ TORCH_MODULE(DeepseekV2DecoderLayer);

class DeepseekV2ModelImpl : public torch::nn::Module {
 public:
-  DeepseekV2ModelImpl(const ModelContext& context)
-      : device_(context.get_tensor_options().device()) {
+  DeepseekV2ModelImpl(const ModelContext& context) {
    auto options = context.get_tensor_options();
    auto model_args = context.get_model_args();
    auto parallel_args = context.get_parallel_args();

    blocks_ = register_module("layers", torch::nn::ModuleList());
    layers_.reserve(model_args.n_layers());
-    // register submodules
-    device_ = options.device();
-    dtype_ = options.dtype().toScalarType();
-    num_speculative_tokens_ = model_args.num_speculative_tokens();

    embed_tokens_ =
-        register_module("embed_tokens", layer::WordEmbedding(context));
-    pos_emb_ = create_rotary_embedding(model_args,
-                                       model_args.rotary_dim(),
-                                       /*interleaved=*/false,
-                                       options);
-    atb_pos_emb_ = layer::PosEmbedding(context);
-
-    max_seq_len_ = model_args.max_position_embeddings();
-    int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
-    attn_mask_ = layer::AttentionMask(options.device(),
-                                      options.dtype().toScalarType(),
-                                      /*mask_value=*/mask_value);
-
+        register_module("embed_tokens",
+                        layer::WordEmbedding(model_args.vocab_size(),
+                                             model_args.hidden_size(),
+                                             context.get_parallel_args(),
+                                             options));
+    norm_ = register_module(
+        "norm",
+        layer::RMSNorm(
+            model_args.hidden_size(), model_args.rms_norm_eps(), options));
+
+    // create decoder layers
    for (int32_t i = 0; i < model_args.n_layers(); ++i) {
      auto block = DeepseekV2DecoderLayer(context, i);
      layers_.push_back(block);
      blocks_->push_back(block);
    }

-    norm_ = register_module("norm", layer::RMSNorm(context));
-    // dp_size_=4;
    dp_size_ = parallel_args.dp_size();
    std::vector<int64_t> indices;
    dp_local_tp_size_ = parallel_args.world_size() / dp_size_;
    dp_rank_ = parallel_args.rank() / dp_local_tp_size_;
    rank_ = parallel_args.rank();
-    mapping_data_ = parallel_args.mapping_data();
-    num_experts_per_tok_ = model_args.num_experts_per_tok();
    for (int i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) {
      indices.push_back(i);
    }
  }

-  torch::Tensor forward(torch::Tensor tokens,
-                        torch::Tensor positions,
-                        std::vector<KVCache>& kv_caches,
-                        const ModelInputParams& input_params) {
-    if (dp_size_ > 1) {
-      if (tokens.sizes() == 0) {
-        tokens = torch::tensor({1}).to(torch::kInt32).to(device_);
-        positions = torch::tensor({0}).to(torch::kInt32).to(device_);
-      }
-    }
-
-    auto h = embed_tokens_(tokens, 0);
-    auto cos_sin = atb_pos_emb_(pos_emb_->get_cos_sin_cache(), positions, 0);
-    auto cos_sin_chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1);
-    auto cos_pos = cos_sin_chunks[0].contiguous();
-    auto sin_pos = cos_sin_chunks[1].contiguous();
-
-    torch::Tensor attn_mask;
-    if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
-      attn_mask = attn_mask_->get_attn_mask(128, dtype_, device_);
-    } else {
-      attn_mask = attn_mask_->gen_free_mask(
-          num_speculative_tokens_ + 1, dtype_, device_);
-    }
-
+  torch::Tensor forward_native(torch::Tensor tokens,
+                               torch::Tensor positions,
+                               std::vector<KVCache>& kv_caches,
+                               const ModelInputParams& input_params) {
+    bool is_prefill = input_params.q_max_seq_len > 1;
+    auto attn_metadata =
+        layer::AttentionMetadata::build(input_params, is_prefill);
+    torch::Tensor hidden_states = embed_tokens_(tokens);
    for (size_t i = 0; i < layers_.size(); i++) {
-      aclrtEvent* event = nullptr;
-      std::atomic<bool>* event_flag = nullptr;
-      if (input_params.layer_synchronizer != nullptr) {
-        event = input_params.layer_synchronizer->get_event(i);
-        event_flag = input_params.layer_synchronizer->get_event_flag(i);
-      }
-      if (input_params.layer_wise_load_synchronizer != nullptr) {
-        if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) {
-          return torch::Tensor();
-        }
-      }
-
      auto& layer = layers_[i];
-      layer(h,
-            cos_pos,
-            sin_pos,
-            attn_mask,
-            kv_caches[i],
-            input_params,
-            event,
-            event_flag);
+      hidden_states = layer(
+          hidden_states, positions, attn_metadata, kv_caches[i], input_params);
    }
-    return norm_(h, 0);
+    return norm_(hidden_states);
+  }
+
+  // Provide batched signature to satisfy callers that pass vectors
+  torch::Tensor forward(const torch::Tensor& tokens,
+                        const torch::Tensor& positions,
+                        std::vector<KVCache>& kv_caches,
+                        const ModelInputParams& input_params) {
+    return forward_native(tokens, positions, kv_caches, input_params);
  }

  // load the weight from the checkpoint
@@ -203,32 +134,6 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
    norm_->load_state_dict(state_dict.get_dict_with_prefix("norm."));
  }

-  void verify_loaded_weights(const std::string& prefix) const {
-    embed_tokens_->verify_loaded_weights(prefix + "embed_tokens.");
-    for (int i = 0; i < layers_.size(); i++) {
-      layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
-                                        ".");
-    }
-    norm_->verify_loaded_weights(prefix + "norm.");
-  }
-
-  void merge_loaded_weights() {
-    embed_tokens_->merge_loaded_weights();
-    for (int i = 0; i < layers_.size(); i++) {
-      layers_[i]->merge_loaded_weights();
-    }
-    norm_->merge_loaded_weights();
-  }
-
-  void prepare_expert_weight(int32_t layer_id,
-                             const std::vector<int32_t>& expert_ids) {
-    layers_[layer_id]->prepare_expert_weight(expert_ids);
-  }
-
-  void update_expert_weight(int32_t layer_id) {
-    layers_[layer_id]->update_expert_weight();
-  }
-
  layer::WordEmbedding get_word_embedding() { return embed_tokens_; }

  void set_word_embedding(layer::WordEmbedding& word_embedding) {
@@ -238,90 +143,20 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
 private:
  torch::nn::ModuleList blocks_{nullptr};
  std::vector<DeepseekV2DecoderLayer> layers_;
-  int32_t max_seq_len_ = 0;
  int32_t dp_rank_;
  int32_t rank_;
  int32_t dp_size_;
  int32_t dp_local_tp_size_;
-  nlohmann::json mapping_data_;
-  int32_t num_experts_per_tok_;
-  int32_t num_speculative_tokens_ = 0;
-  at::Device device_;
-  torch::Dtype dtype_;
  layer::WordEmbedding embed_tokens_{nullptr};
-  std::shared_ptr<RotaryEmbedding> pos_emb_{nullptr};
-  layer::PosEmbedding atb_pos_emb_{nullptr};
-  layer::AttentionMask attn_mask_{nullptr};
  layer::RMSNorm norm_{nullptr};
};
TORCH_MODULE(DeepseekV2Model);

-class DeepseekV2ForCausalLMImpl : public torch::nn::Module {
+class DeepseekV2ForCausalLMImpl
+    : public LlmForCausalLMImplBase<DeepseekV2Model> {
 public:
-  DeepseekV2ForCausalLMImpl(const ModelContext& context) {
-    model_ = register_module("model", DeepseekV2Model(context));
-    lm_head_ = register_module("lm_head", layer::LmHead(context));
-    first_k_dense_replace_ = context.get_model_args().first_k_dense_replace();
-  }
-
-  // tokens: [num_tokens]
-  // positions: [num_tokens] token pos in the sequence
-  // returns: [num_tokens, hidden_size]
-  torch::Tensor forward(const torch::Tensor& tokens,
-                        const torch::Tensor& positions,
-                        std::vector<KVCache>& kv_caches,
-                        const ModelInputParams& input_params) {
-    return model_(tokens, positions, kv_caches, input_params);
-  }
-
-  // hidden_states: [num_tokens, hidden_size]
-  // seleted_idxes: [num_tokens]
-  // returns: [num_tokens, vocab_size]
-  torch::Tensor logits(const torch::Tensor& hidden_states,
-                       const torch::Tensor& seleted_idxes) {
-    return lm_head_(hidden_states, seleted_idxes, 0);
-  }
-
-  void load_model(std::unique_ptr<ModelLoader> loader) {
-    for (const auto& state_dict : loader->get_state_dicts()) {
-      model_->load_state_dict(state_dict->get_dict_with_prefix("model."));
-      lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head."));
-    }
-
-    // verify
-    model_->verify_loaded_weights("model.");
-    lm_head_->verify_loaded_weights("lm_head.");
-
-    model_->merge_loaded_weights();
-    lm_head_->merge_loaded_weights();
-  }
-
-  void prepare_expert_weight(int32_t layer_id,
-                             const std::vector<int32_t>& expert_ids) {
-    model_->prepare_expert_weight(layer_id + first_k_dense_replace_,
-                                  expert_ids);
-  }
-
-  void update_expert_weight(int32_t layer_id) {
-    model_->update_expert_weight(layer_id + first_k_dense_replace_);
-  }
-
-  layer::LmHead get_lm_head() { return lm_head_; }
-
-  void set_lm_head(layer::LmHead& head) { lm_head_ = head; }
-
-  layer::WordEmbedding get_word_embedding() {
-    return model_->get_word_embedding();
-  }
-
-  void set_word_embedding(layer::WordEmbedding& word_embedding) {
-    model_->set_word_embedding(word_embedding);
-  }
-
- private:
-  DeepseekV2Model model_{nullptr};
-  layer::LmHead lm_head_{nullptr};
-  int32_t first_k_dense_replace_;
+  DeepseekV2ForCausalLMImpl(const ModelContext& context)
+      : LlmForCausalLMImplBase<DeepseekV2Model>(context) {}
};
TORCH_MODULE(DeepseekV2ForCausalLM);

@@ -365,6 +200,7 @@ REGISTER_MODEL_ARGS(deepseek_v2, [&] {
  LOAD_ARG_OR(v_head_dim, "v_head_dim", 128);
  LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 0);
  LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512);
+  LOAD_ARG_OR(num_nextn_predict_layers, "num_nextn_predict_layers", 1);

  LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
    return 256;  // args->qk_nope_head_dim() + args->qk_rope_head_dim();
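
For orientation, here is a minimal, self-contained sketch of the forward-path pattern the refactored forward_native() adopts: attention metadata is built once per forward call from the input params and then threaded through every decoder layer together with that layer's KV cache, replacing the per-layer cos/sin, mask, and event plumbing that the diff removes. The types below (Tensor, KVCache, ModelInputParams, AttentionMetadata, DecoderLayer) are illustrative stubs with assumed shapes, not xllm's real classes.

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative stand-ins for xllm's real types; names are assumptions taken
// from the diff, and the bodies are deliberately trivial.
struct Tensor {};                 // placeholder for torch::Tensor
struct KVCache {};                // per-layer key/value cache handle
struct ModelInputParams {
  int32_t q_max_seq_len = 1;      // >1 is treated as a prefill batch
};
struct AttentionMetadata {
  bool is_prefill = false;
  static AttentionMetadata build(const ModelInputParams& /*params*/,
                                 bool prefill) {
    AttentionMetadata meta;
    meta.is_prefill = prefill;    // the real build() would also capture seq lens
    return meta;
  }
};
struct DecoderLayer {
  Tensor forward(Tensor x,
                 const Tensor& /*positions*/,
                 const AttentionMetadata& /*attn_metadata*/,
                 KVCache& /*kv_cache*/,
                 const ModelInputParams& /*input_params*/) {
    return x;                     // a real layer runs attention + MoE/MLP here
  }
};

// Mirrors DeepseekV2ModelImpl::forward_native above: build the metadata once,
// then pass it to every layer together with that layer's KV cache.
Tensor run_model(Tensor hidden_states,
                 const Tensor& positions,
                 std::vector<DecoderLayer>& layers,
                 std::vector<KVCache>& kv_caches,
                 const ModelInputParams& input_params) {
  const bool is_prefill = input_params.q_max_seq_len > 1;
  const AttentionMetadata attn_metadata =
      AttentionMetadata::build(input_params, is_prefill);
  for (std::size_t i = 0; i < layers.size(); ++i) {
    hidden_states = layers[i].forward(
        hidden_states, positions, attn_metadata, kv_caches[i], input_params);
  }
  return hidden_states;           // the real model applies a final RMSNorm
}

Compared with the removed code, this keeps rotary cos/sin tensors, attention masks, and ACL event synchronization out of the model-level loop; those concerns move behind AttentionMetadata and the layer implementation.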