
Commit fe6ad67

DongheJin authored and yq33victor committed
bugfix: resolve multi-machine communication domain error.
1 parent d94d1b3 commit fe6ad67

File tree: 3 files changed, +37 additions, −14 deletions

xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp

Lines changed: 13 additions & 0 deletions
@@ -70,6 +70,7 @@ void Qwen3DecoderLayerImpl::param_from_args(
   param.enableIntraLayerAddNorm = true;
   param.enableInterLayerAddNorm = false;
   param.enablePreFetchWeight = FLAGS_enable_prefetch_weight;
+  initialize_parallel_parameters(param, parallel_args);
   initialize_quantization_parameters(param);
 
   if (isPrefill) {
@@ -89,6 +90,18 @@ void Qwen3DecoderLayerImpl::param_from_args(
   }
 }
 
+void Qwen3DecoderLayerImpl::initialize_parallel_parameters(
+    atb_speed::qwen::QwenLayerParam& param,
+    const ParallelArgs& parallel_args) {
+  param.mapping = parallel_args.mapping();
+  param.tensorParallelInfo = {parallel_args.rank(),
+                              parallel_args.world_size(),
+                              FLAGS_communication_backend,
+                              FLAGS_rank_tablefile,
+                              nullptr,
+                              ""};
+}
+
 void Qwen3DecoderLayerImpl::initialize_quantization_parameters(
     atb_speed::qwen::QwenLayerParam& param) {
   if (quantize_type_.empty()) {
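
The brace initializer above fills tensorParallelInfo positionally. As a reading aid, here is a minimal C++ sketch of the layout it appears to assume, with field names inferred from the member accesses elsewhere in this commit (rank, worldSize, backend, commDomain, hcommInfo) and from the flags being passed; the actual atb_speed type may differ.

// Hypothetical layout; field names are inferred from this commit, not taken
// from the real atb_speed headers.
#include <string>

struct TensorParallelInfoSketch {
  int rank = 0;               // parallel_args.rank()
  int worldSize = 1;          // parallel_args.world_size()
  std::string backend;        // FLAGS_communication_backend, e.g. "hccl" or "lccl"
  std::string rankTableFile;  // FLAGS_rank_tablefile, needed for multi-machine setups
  void* hcommInfo = nullptr;  // communicator handle, created later
  std::string commDomain;     // communication domain id, empty until initialized
};

// Mirrors the shape of: param.tensorParallelInfo = {rank, world_size, backend,
//                                                   rank_table_file, nullptr, ""};
TensorParallelInfoSketch make_info(int rank,
                                   int world_size,
                                   const std::string& backend,
                                   const std::string& rank_table_file) {
  return {rank, world_size, backend, rank_table_file, nullptr, ""};
}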

xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h

Lines changed: 3 additions & 0 deletions
@@ -78,6 +78,9 @@ class Qwen3DecoderLayerImpl : public BaseLayer {
                   ModelInputParams& input_params,
                   bool is_prefill);
 
+  void initialize_parallel_parameters(atb_speed::qwen::QwenLayerParam& param,
+                                      const ParallelArgs& parallel_args);
+
   void initialize_quantization_parameters(
       atb_speed::qwen::QwenLayerParam& param);
 

xllm/core/layers/npu/npu_word_embedding_impl.cpp

Lines changed: 21 additions & 14 deletions
@@ -27,21 +27,28 @@ void WordEmbeddingImpl::param_from_args(
     const xllm::ModelArgs& args,
     const xllm::ParallelArgs& parallel_args) {
   param.unpadInputs = true;
-  if (dp_size_ > 1) {
-    param.tensorParallelInfo.rank = dp_local_tp_rank_;
-    param.tensorParallelInfo.worldSize = dp_local_tp_size_;
-    param.tensorParallelInfo.backend = FLAGS_communication_backend;
-  } else if (parallel_args.world_size() != 1) {
-    // param.tensorParallelInfo = {parallel_args.rank(),
-    //                             parallel_args.world_size(), "lccl"};
-    param.tensorParallelInfo = {parallel_args.rank(),
-                                parallel_args.world_size(),
-                                FLAGS_communication_backend};
+
+  if (parallel_args.world_size() > 1) {
+    if (parallel_args.mapping_data().empty()) {
+      if (dp_size_ > 1) {
+        param.tensorParallelInfo.rank = dp_local_tp_rank_;
+        param.tensorParallelInfo.worldSize = dp_local_tp_size_;
+      } else {
+        param.tensorParallelInfo.rank = parallel_args.rank();
+        param.tensorParallelInfo.worldSize = parallel_args.world_size();
+      }
+      param.tensorParallelInfo.commDomain = std::to_string(dp_rank_);
+      param.tensorParallelInfo.backend = FLAGS_communication_backend;
+    } else {
+      atb_speed::common::ParallelInfo parallelInfo =
+          parallel_args.mapping().Get(atb_speed::base::ATTN_TP);
+      param.tensorParallelInfo.rank = parallelInfo.rank;
+      param.tensorParallelInfo.worldSize = parallelInfo.rankIds.size();
+      param.tensorParallelInfo.backend = FLAGS_communication_backend;
+      parallelInfo.InitCommDomain(param.tensorParallelInfo.hcommInfo,
+                                  param.tensorParallelInfo.commDomain);
+    }
   }
-  // param.linearParallelParam.tensorParallelInfo.backend =
-  //     FLAGS_communication_backend;
-  param.tensorParallelInfo.commDomain = std::to_string(dp_rank_);
-  // param.tensorParallelInfo.rankTableFile = FLAGS_rank_tablefile;
 }
 
 WordEmbeddingImpl::WordEmbeddingImpl(const ModelContext& context)
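
The rewritten branch chooses the tensor-parallel communication parameters in two ways: without a rank-table mapping it keeps the previous single-node behaviour (DP-local TP rank and size, commDomain derived from dp_rank_), while with a mapping it resolves the ATTN_TP group and lets InitCommDomain produce a domain id that is consistent across machines. Below is a self-contained sketch of that decision using simplified stand-in types; the names and signatures are illustrative, not the project's real API.

#include <string>
#include <vector>

// Simplified stand-ins for the real types; illustrative only.
struct TpInfo {
  int rank = 0;
  int worldSize = 1;
  std::string backend;
  std::string commDomain;
};

struct TpGroup {
  int rank = 0;
  std::vector<int> rankIds;  // global ranks belonging to this TP group
  std::string commDomain;    // precomputed, machine-consistent domain id
};

// Chooses TP communication parameters the same way the patched
// WordEmbeddingImpl::param_from_args does (sketch, not the real code).
TpInfo select_tp_info(int world_size,
                      bool has_mapping,        // parallel_args.mapping_data() non-empty
                      const TpGroup& attn_tp,  // mapping().Get(ATTN_TP) equivalent
                      int dp_size, int dp_rank,
                      int dp_local_tp_rank, int dp_local_tp_size,
                      int global_rank, const std::string& backend) {
  TpInfo info;
  if (world_size <= 1) return info;  // single rank: nothing to configure
  info.backend = backend;
  if (!has_mapping) {
    // Single-machine path: ranks within one DP replica, domain keyed by dp_rank.
    info.rank = (dp_size > 1) ? dp_local_tp_rank : global_rank;
    info.worldSize = (dp_size > 1) ? dp_local_tp_size : world_size;
    info.commDomain = std::to_string(dp_rank);
  } else {
    // Multi-machine path: take the ATTN_TP group from the mapping so every
    // machine derives the same communication domain.
    info.rank = attn_tp.rank;
    info.worldSize = static_cast<int>(attn_tp.rankIds.size());
    info.commDomain = attn_tp.commDomain;
  }
  return info;
}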
