@@ -27,21 +27,28 @@ void WordEmbeddingImpl::param_from_args(
2727 const xllm::ModelArgs& args,
2828 const xllm::ParallelArgs& parallel_args) {
2929 param.unpadInputs = true;
30- if (dp_size_ > 1) {
31- param.tensorParallelInfo.rank = dp_local_tp_rank_;
32- param.tensorParallelInfo.worldSize = dp_local_tp_size_;
33- param.tensorParallelInfo.backend = FLAGS_communication_backend;
34- } else if (parallel_args.world_size() != 1) {
35- // param.tensorParallelInfo = {parallel_args.rank(),
36- // parallel_args.world_size(), "lccl"};
37- param.tensorParallelInfo = {parallel_args.rank(),
38- parallel_args.world_size(),
39- FLAGS_communication_backend};
30+
31+ if (parallel_args.world_size() > 1) {
32+ if (parallel_args.mapping_data().empty()) {
33+ if (dp_size_ > 1) {
34+ param.tensorParallelInfo.rank = dp_local_tp_rank_;
35+ param.tensorParallelInfo.worldSize = dp_local_tp_size_;
36+ } else {
37+ param.tensorParallelInfo.rank = parallel_args.rank();
38+ param.tensorParallelInfo.worldSize = parallel_args.world_size();
39+ }
40+ param.tensorParallelInfo.commDomain = std::to_string(dp_rank_);
41+ param.tensorParallelInfo.backend = FLAGS_communication_backend;
42+ } else {
43+ atb_speed::common::ParallelInfo parallelInfo =
44+ parallel_args.mapping().Get(atb_speed::base::ATTN_TP);
45+ param.tensorParallelInfo.rank = parallelInfo.rank;
46+ param.tensorParallelInfo.worldSize = parallelInfo.rankIds.size();
47+ param.tensorParallelInfo.backend = FLAGS_communication_backend;
48+ parallelInfo.InitCommDomain(param.tensorParallelInfo.hcommInfo,
49+ param.tensorParallelInfo.commDomain);
50+ }
4051 }
41- // param.linearParallelParam.tensorParallelInfo.backend =
42- // FLAGS_communication_backend;
43- param.tensorParallelInfo.commDomain = std::to_string(dp_rank_);
44- // param.tensorParallelInfo.rankTableFile = FLAGS_rank_tablefile;
4552}
4653
4754WordEmbeddingImpl::WordEmbeddingImpl (const ModelContext& context)
0 commit comments