
Commit 8c2401a

refactor: separate mlu and cuda version Qwen model implementation.
1 parent 8a2110c · commit 8c2401a

34 files changed: +3542 −991 lines

xllm/models/llm/mlu/deepseek_mtp.h renamed to xllm/models/llm/deepseek_mtp.h

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ limitations under the License.
 #include <vector>

 #include "core/layers/deepseek_v2_decoder_layer.h"
-#include "models/llm/llm_model_base.h"
+#include "llm_model_base.h"

 // DeepSeek v2 compatible with huggingface weights
 // ref to:

xllm/models/llm/deepseek_v2.h

Lines changed: 50 additions & 214 deletions
@@ -12,79 +12,48 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
 #pragma once

-#include <gflags/gflags.h>
 #include <torch/torch.h>

-#include <boost/algorithm/string.hpp>
 #include <string>
 #include <vector>

-#include "core/common/global_flags.h"
-#include "core/framework/kv_cache/kv_cache.h"
-#include "core/framework/model/model_input_params.h"
-#include "core/framework/model/npu_dp_ep_padding.h"
-#include "core/framework/model_context.h"
-#include "core/layers/common/attention_mask_impl.h"
 #include "core/layers/deepseek_v2_decoder_layer.h"
-#include "core/layers/lm_head.h"
-#include "core/layers/npu/npu_rms_norm_impl.h"
-#include "core/layers/npu/rotary_embedding.h"
-#include "core/layers/pos_embedding.h"
-#include "core/layers/word_embedding.h"
-#include "models/model_registry.h"
+#include "llm_model_base.h"
+
 // DeepSeek v2 compatible with huggingface weights
 // ref to:
 // https://github.com/vllm-project/vllm/blob/v0.6.6/vllm/model_executor/models/deepseek_v2.py

 namespace xllm {

-using torch::indexing::None;
-using ISlice = torch::indexing::Slice;
-
 class DeepseekV2DecoderLayerImpl : public torch::nn::Module {
  public:
-  DeepseekV2DecoderLayerImpl(const ModelContext& context, const int32_t i) {
+  DeepseekV2DecoderLayerImpl(const ModelContext& context,
+                             const int32_t layer_index) {
     // register submodules
-    decoder_layer_ = register_module("decoder_layer",
-                                     layer::DeepseekV2DecoderLayer(context, i));
+    decoder_layer_ = register_module(
+        "decoder_layer", layer::DeepseekV2DecoderLayer(context, layer_index));
   }

   torch::Tensor forward(torch::Tensor& x,
-                        torch::Tensor& cos_pos,
-                        torch::Tensor& sin_pos,
-                        torch::Tensor& attn_mask,
+                        torch::Tensor& positions,
+                        const layer::AttentionMetadata& attn_metadata,
                         KVCache& kv_cache,
-                        const ModelInputParams& input_params,
-                        aclrtEvent* event,
-                        std::atomic<bool>* event_flag) {
-    return decoder_layer_(x,
-                          cos_pos,
-                          sin_pos,
-                          attn_mask,
-                          kv_cache,
-                          input_params,
-                          event,
-                          event_flag);
+                        const ModelInputParams& input_params) {
+    return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params);
   }

   void load_state_dict(const StateDict& state_dict) {
     decoder_layer_->load_state_dict(state_dict);
   }

-  void verify_loaded_weights(const std::string& prefix) const {
-    decoder_layer_->verify_loaded_weights(prefix);
+  virtual void prepare_expert_weight(int32_t layer_id,
+                                     const std::vector<int32_t>& expert_ids) {
+    return;
   }
-
-  void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); }
-
-  void prepare_expert_weight(const std::vector<int32_t>& expert_list) {
-    decoder_layer_->prepare_expert_weight(expert_list);
-  }
-
-  void update_expert_weight() { decoder_layer_->update_expert_weight(); }
+  virtual void update_expert_weight(int32_t layer_id) { return; }

  private:
   layer::DeepseekV2DecoderLayer decoder_layer_{nullptr};
@@ -93,102 +62,64 @@ TORCH_MODULE(DeepseekV2DecoderLayer);

 class DeepseekV2ModelImpl : public torch::nn::Module {
  public:
-  DeepseekV2ModelImpl(const ModelContext& context)
-      : device_(context.get_tensor_options().device()) {
+  DeepseekV2ModelImpl(const ModelContext& context) {
     auto options = context.get_tensor_options();
     auto model_args = context.get_model_args();
     auto parallel_args = context.get_parallel_args();

     blocks_ = register_module("layers", torch::nn::ModuleList());
     layers_.reserve(model_args.n_layers());
-    // register submodules
-    device_ = options.device();
-    dtype_ = options.dtype().toScalarType();
-    num_speculative_tokens_ = model_args.num_speculative_tokens();

     embed_tokens_ =
-        register_module("embed_tokens", layer::WordEmbedding(context));
-    pos_emb_ = create_rotary_embedding(model_args,
-                                       model_args.rotary_dim(),
-                                       /*interleaved=*/false,
-                                       options);
-    atb_pos_emb_ = layer::PosEmbedding(context);
-
-    max_seq_len_ = model_args.max_position_embeddings();
-    int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
-    attn_mask_ = layer::AttentionMask(options.device(),
-                                      options.dtype().toScalarType(),
-                                      /*mask_value=*/mask_value);
-
+        register_module("embed_tokens",
+                        layer::WordEmbedding(model_args.vocab_size(),
+                                             model_args.hidden_size(),
+                                             context.get_parallel_args(),
+                                             options));
+    norm_ = register_module(
+        "norm",
+        layer::RMSNorm(
+            model_args.hidden_size(), model_args.rms_norm_eps(), options));
+
+    // create decoder layers
     for (int32_t i = 0; i < model_args.n_layers(); ++i) {
       auto block = DeepseekV2DecoderLayer(context, i);
       layers_.push_back(block);
       blocks_->push_back(block);
     }

-    norm_ = register_module("norm", layer::RMSNorm(context));
-    // dp_size_=4;
     dp_size_ = parallel_args.dp_size();
     std::vector<int64_t> indices;
     dp_local_tp_size_ = parallel_args.world_size() / dp_size_;
     dp_rank_ = parallel_args.rank() / dp_local_tp_size_;
     rank_ = parallel_args.rank();
-    mapping_data_ = parallel_args.mapping_data();
-    num_experts_per_tok_ = model_args.num_experts_per_tok();
     for (int i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) {
       indices.push_back(i);
     }
   }

-  torch::Tensor forward(torch::Tensor tokens,
-                        torch::Tensor positions,
-                        std::vector<KVCache>& kv_caches,
-                        const ModelInputParams& input_params) {
-    if (dp_size_ > 1) {
-      if (tokens.sizes() == 0) {
-        tokens = torch::tensor({1}).to(torch::kInt32).to(device_);
-        positions = torch::tensor({0}).to(torch::kInt32).to(device_);
-      }
-    }
-
-    auto h = embed_tokens_(tokens, 0);
-    auto cos_sin = atb_pos_emb_(pos_emb_->get_cos_sin_cache(), positions, 0);
-    auto cos_sin_chunks = cos_sin.chunk(/*chunks=*/2, /*dim=*/-1);
-    auto cos_pos = cos_sin_chunks[0].contiguous();
-    auto sin_pos = cos_sin_chunks[1].contiguous();
-
-    torch::Tensor attn_mask;
-    if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
-      attn_mask = attn_mask_->get_attn_mask(128, dtype_, device_);
-    } else {
-      attn_mask = attn_mask_->gen_free_mask(
-          num_speculative_tokens_ + 1, dtype_, device_);
-    }
-
+  torch::Tensor forward_native(torch::Tensor tokens,
+                               torch::Tensor positions,
+                               std::vector<KVCache>& kv_caches,
+                               const ModelInputParams& input_params) {
+    bool is_prefill = input_params.q_max_seq_len > 1;
+    auto attn_metadata =
+        layer::AttentionMetadata::build(input_params, is_prefill);
+    torch::Tensor hidden_states = embed_tokens_(tokens);
     for (size_t i = 0; i < layers_.size(); i++) {
-      aclrtEvent* event = nullptr;
-      std::atomic<bool>* event_flag = nullptr;
-      if (input_params.layer_synchronizer != nullptr) {
-        event = input_params.layer_synchronizer->get_event(i);
-        event_flag = input_params.layer_synchronizer->get_event_flag(i);
-      }
-      if (input_params.layer_wise_load_synchronizer != nullptr) {
-        if (!input_params.layer_wise_load_synchronizer->synchronize_layer(i)) {
-          return torch::Tensor();
-        }
-      }
-
       auto& layer = layers_[i];
-      layer(h,
-            cos_pos,
-            sin_pos,
-            attn_mask,
-            kv_caches[i],
-            input_params,
-            event,
-            event_flag);
+      hidden_states = layer(
+          hidden_states, positions, attn_metadata, kv_caches[i], input_params);
     }
-    return norm_(h, 0);
+    return norm_(hidden_states);
+  }
+
+  // Provide batched signature to satisfy callers that pass vectors
+  torch::Tensor forward(const torch::Tensor& tokens,
+                        const torch::Tensor& positions,
+                        std::vector<KVCache>& kv_caches,
+                        const ModelInputParams& input_params) {
+    return forward_native(tokens, positions, kv_caches, input_params);
   }

   // load the weight from the checkpoint
@@ -203,32 +134,6 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
     norm_->load_state_dict(state_dict.get_dict_with_prefix("norm."));
   }

-  void verify_loaded_weights(const std::string& prefix) const {
-    embed_tokens_->verify_loaded_weights(prefix + "embed_tokens.");
-    for (int i = 0; i < layers_.size(); i++) {
-      layers_[i]->verify_loaded_weights(prefix + "layers." + std::to_string(i) +
-                                        ".");
-    }
-    norm_->verify_loaded_weights(prefix + "norm.");
-  }
-
-  void merge_loaded_weights() {
-    embed_tokens_->merge_loaded_weights();
-    for (int i = 0; i < layers_.size(); i++) {
-      layers_[i]->merge_loaded_weights();
-    }
-    norm_->merge_loaded_weights();
-  }
-
-  void prepare_expert_weight(int32_t layer_id,
-                             const std::vector<int32_t>& expert_ids) {
-    layers_[layer_id]->prepare_expert_weight(expert_ids);
-  }
-
-  void update_expert_weight(int32_t layer_id) {
-    layers_[layer_id]->update_expert_weight();
-  }
-
   layer::WordEmbedding get_word_embedding() { return embed_tokens_; }

   void set_word_embedding(layer::WordEmbedding& word_embedding) {
@@ -238,90 +143,20 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
  private:
   torch::nn::ModuleList blocks_{nullptr};
   std::vector<DeepseekV2DecoderLayer> layers_;
-  int32_t max_seq_len_ = 0;
   int32_t dp_rank_;
   int32_t rank_;
   int32_t dp_size_;
   int32_t dp_local_tp_size_;
-  nlohmann::json mapping_data_;
-  int32_t num_experts_per_tok_;
-  int32_t num_speculative_tokens_ = 0;
-  at::Device device_;
-  torch::Dtype dtype_;
   layer::WordEmbedding embed_tokens_{nullptr};
-  std::shared_ptr<RotaryEmbedding> pos_emb_{nullptr};
-  layer::PosEmbedding atb_pos_emb_{nullptr};
-  layer::AttentionMask attn_mask_{nullptr};
   layer::RMSNorm norm_{nullptr};
 };
 TORCH_MODULE(DeepseekV2Model);

-class DeepseekV2ForCausalLMImpl : public torch::nn::Module {
+class DeepseekV2ForCausalLMImpl
+    : public LlmForCausalLMImplBase<DeepseekV2Model> {
  public:
-  DeepseekV2ForCausalLMImpl(const ModelContext& context) {
-    model_ = register_module("model", DeepseekV2Model(context));
-    lm_head_ = register_module("lm_head", layer::LmHead(context));
-    first_k_dense_replace_ = context.get_model_args().first_k_dense_replace();
-  }
-
-  // tokens: [num_tokens]
-  // positions: [num_tokens] token pos in the sequence
-  // returns: [num_tokens, hidden_size]
-  torch::Tensor forward(const torch::Tensor& tokens,
-                        const torch::Tensor& positions,
-                        std::vector<KVCache>& kv_caches,
-                        const ModelInputParams& input_params) {
-    return model_(tokens, positions, kv_caches, input_params);
-  }
-
-  // hidden_states: [num_tokens, hidden_size]
-  // seleted_idxes: [num_tokens]
-  // returns: [num_tokens, vocab_size]
-  torch::Tensor logits(const torch::Tensor& hidden_states,
-                       const torch::Tensor& seleted_idxes) {
-    return lm_head_(hidden_states, seleted_idxes, 0);
-  }
-
-  void load_model(std::unique_ptr<ModelLoader> loader) {
-    for (const auto& state_dict : loader->get_state_dicts()) {
-      model_->load_state_dict(state_dict->get_dict_with_prefix("model."));
-      lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head."));
-    }
-
-    // verify
-    model_->verify_loaded_weights("model.");
-    lm_head_->verify_loaded_weights("lm_head.");
-
-    model_->merge_loaded_weights();
-    lm_head_->merge_loaded_weights();
-  }
-
-  void prepare_expert_weight(int32_t layer_id,
-                             const std::vector<int32_t>& expert_ids) {
-    model_->prepare_expert_weight(layer_id + first_k_dense_replace_,
-                                  expert_ids);
-  }
-
-  void update_expert_weight(int32_t layer_id) {
-    model_->update_expert_weight(layer_id + first_k_dense_replace_);
-  }
-
-  layer::LmHead get_lm_head() { return lm_head_; }
-
-  void set_lm_head(layer::LmHead& head) { lm_head_ = head; }
-
-  layer::WordEmbedding get_word_embedding() {
-    return model_->get_word_embedding();
-  }
-
-  void set_word_embedding(layer::WordEmbedding& word_embedding) {
-    model_->set_word_embedding(word_embedding);
-  }
-
- private:
-  DeepseekV2Model model_{nullptr};
-  layer::LmHead lm_head_{nullptr};
-  int32_t first_k_dense_replace_;
+  DeepseekV2ForCausalLMImpl(const ModelContext& context)
+      : LlmForCausalLMImplBase<DeepseekV2Model>(context) {}
 };
 TORCH_MODULE(DeepseekV2ForCausalLM);

@@ -365,6 +200,7 @@ REGISTER_MODEL_ARGS(deepseek_v2, [&] {
   LOAD_ARG_OR(v_head_dim, "v_head_dim", 128);
   LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 0);
   LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512);
+  LOAD_ARG_OR(num_nextn_predict_layers, "num_nextn_predict_layers", 1);

   LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
     return 256;  // args->qk_nope_head_dim() + args->qk_rope_head_dim();
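
For orientation, a minimal caller-side sketch of the refactored path follows. It is not part of this commit; the include path and the setup of ModelContext, KVCache, and ModelInputParams are assumed scaffolding. It illustrates that the device-specific inputs the old forward() required (cos/sin rotary tensors, an attention mask, aclrtEvent plumbing) no longer appear at the call site; forward_native() derives layer::AttentionMetadata from input_params internally.

// Hypothetical usage sketch; only types and signatures visible in the diff
// above are relied upon, everything else (including the include path) is an
// assumption, not code from this commit.
#include "models/llm/deepseek_v2.h"

torch::Tensor run_deepseek_v2(xllm::DeepseekV2Model& model,
                              const torch::Tensor& tokens,
                              const torch::Tensor& positions,
                              std::vector<xllm::KVCache>& kv_caches,
                              const xllm::ModelInputParams& input_params) {
  // forward() is the backend-neutral entry point added in this diff; it
  // delegates to forward_native(), which builds the attention metadata.
  return model->forward(tokens, positions, kv_caches, input_params);
}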

xllm/models/llm/deepseek_v3.h

Lines changed: 3 additions & 0 deletions
@@ -42,6 +42,7 @@ REGISTER_MODEL_ARGS(deepseek_v3, [&] {
   LOAD_ARG_OR(max_window_layers, "max_window_layers", 61);

   LOAD_ARG_OR(first_k_dense_replace, "first_k_dense_replace", 3);
+  LOAD_ARG_OR(hidden_act, "hidden_act", "silu");
   LOAD_ARG_OR(moe_layer_freq, "moe_layer_freq", 1);
   LOAD_ARG_OR(topk_method, "topk_method", "noaux_tc");
   LOAD_ARG_OR(n_routed_experts, "n_routed_experts", 256);
@@ -52,11 +53,13 @@
   LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", true);
   LOAD_ARG_OR(n_group, "n_group", 8);
   LOAD_ARG_OR(topk_group, "topk_group", 4);
+  LOAD_ARG_OR(scoring_func, "scoring_func", "sigmoid");
   LOAD_ARG_OR(qk_nope_head_dim, "qk_nope_head_dim", 128);
   LOAD_ARG_OR(qk_rope_head_dim, "qk_rope_head_dim", 64);
   LOAD_ARG_OR(v_head_dim, "v_head_dim", 128);
   LOAD_ARG_OR(q_lora_rank, "q_lora_rank", 1536);
   LOAD_ARG_OR(kv_lora_rank, "kv_lora_rank", 512);
+  LOAD_ARG_OR(num_nextn_predict_layers, "num_nextn_predict_layers", 1);

   LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
     return 256;  // args->qk_nope_head_dim() + args->qk_rope_head_dim();
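
The three newly registered deepseek_v3 arguments (hidden_act, scoring_func, num_nextn_predict_layers) are read from the HuggingFace config with the defaults shown above. A hedged sketch of how such arguments are typically consumed follows; the xllm::ModelArgs type name and the accessor spellings are assumptions inferred from the LOAD_ARG_OR convention and from calls like model_args.n_layers() earlier in the diff, not code from this commit.

// Hypothetical consumer of the newly registered arguments; accessor names
// are assumed from the LOAD_ARG_OR convention and may differ in the repo.
void inspect_deepseek_v3_args(const xllm::ModelArgs& args) {
  const std::string act = args.hidden_act();        // defaults to "silu"
  const std::string score = args.scoring_func();    // defaults to "sigmoid"
  const int32_t mtp = args.num_nextn_predict_layers();  // defaults to 1
  // num_nextn_predict_layers presumably sizes the multi-token-prediction
  // (MTP) path implemented in deepseek_mtp.h.
  (void)act;
  (void)score;
  (void)mtp;
}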
File renamed without changes.
