@@ -38,8 +38,8 @@ class DeepseekMultiTokenPredictorLayerImpl : public torch::nn::Module {
3838 auto parallel_args = context.get_parallel_args ();
3939
4040 // register submodules
41- enorm_ = register_module (" enorm" , layer::RmsNorm (context));
42- hnorm_ = register_module (" hnorm" , layer::RmsNorm (context));
41+ enorm_ = register_module (" enorm" , layer::RMSNorm (context));
42+ hnorm_ = register_module (" hnorm" , layer::RMSNorm (context));
4343 // no quantization for eh_proj
4444 eh_proj_ =
4545 register_module (" eh_proj" ,
@@ -92,8 +92,8 @@ class DeepseekMultiTokenPredictorLayerImpl : public torch::nn::Module {
9292 virtual void update_expert_weight (int32_t layer_id) { return ; }
9393
9494 private:
95- layer::RmsNorm enorm_{nullptr };
96- layer::RmsNorm hnorm_{nullptr };
95+ layer::RMSNorm enorm_{nullptr };
96+ layer::RMSNorm hnorm_{nullptr };
9797 layer::ReplicatedLinear eh_proj_{nullptr };
9898 layer::DeepseekV2DecoderLayer mtp_block_{nullptr };
9999};
@@ -125,7 +125,7 @@ class DeepseekMTPModelImpl : public torch::nn::Module {
125125 model_args.hidden_size (),
126126 context.get_parallel_args (),
127127 options));
128- norm_ = register_module (" norm" , layer::RmsNorm (context));
128+ norm_ = register_module (" norm" , layer::RMSNorm (context));
129129
130130 // get dp size and rank
131131 dp_size_ = parallel_args.dp_size ();
@@ -196,7 +196,7 @@ class DeepseekMTPModelImpl : public torch::nn::Module {
196196 int32_t dp_size_;
197197 int32_t dp_local_tp_size_;
198198 layer::WordEmbedding embed_tokens_{nullptr };
199- layer::RmsNorm norm_{nullptr };
199+ layer::RMSNorm norm_{nullptr };
200200};
201201TORCH_MODULE (DeepseekMTPModel);
202202
0 commit comments