
Commit 6b246e2

feat: add moe all2all kernels and deep ep layer. (#497)
Co-authored-by: phantomlei3 <phantomlei3@gmail.com>
1 parent cb2443e commit 6b246e2

File tree: 11 files changed, +1205 -11 lines

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlu_ops_api.h"
+
+namespace xllm::kernel::mlu {
+
+void gather_split(const torch::Tensor& input,
+                  const torch::Tensor& gather_index,
+                  const torch::Tensor& valid_token_num,
+                  const torch::Tensor& output_head,
+                  const torch::Tensor& output_tail) {
+  tmo::torch_api::gather_split(
+      output_head, output_tail, input, gather_index, valid_token_num);
+}
+
+}  // namespace xllm::kernel::mlu
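
For orientation, here is a minimal caller-side sketch of this wrapper. Everything below is an illustrative assumption: the device, the shapes, and the 64/64 head/tail capacities; the actual split semantics live in the underlying tmo::torch_api::gather_split kernel, which this file only forwards to.

// Hypothetical driver for xllm::kernel::mlu::gather_split; shapes, dtypes,
// and the head/tail buffer sizes are illustrative assumptions.
#include <torch/torch.h>
#include "mlu_ops_api.h"

void gather_split_example(const torch::Device& dev) {
  auto int_opts = torch::TensorOptions().dtype(torch::kInt).device(dev);
  auto input = torch::randn({128, 512}, torch::TensorOptions().device(dev));
  auto gather_index = torch::randperm(128, int_opts);    // row permutation
  auto valid_token_num = torch::tensor({96}, int_opts);  // tokens actually used
  // Outputs are preallocated by the caller; the kernel writes into them.
  auto output_head = torch::empty({64, 512}, input.options());
  auto output_tail = torch::empty({64, 512}, input.options());
  xllm::kernel::mlu::gather_split(
      input, gather_index, valid_token_num, output_head, output_tail);
}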

xllm/core/kernels/mlu/mlu_ops_api.h

Lines changed: 46 additions & 0 deletions
@@ -182,6 +182,46 @@ torch::Tensor moe_combine_result(
     const int64_t expert_size,
     const std::optional<torch::Tensor>& bias);
 
+torch::Tensor moe_all2all_gen_send_layout(const torch::Tensor& token_count,
+                                          int64_t nrank);
+
+std::vector<torch::Tensor> moe_all2all_gen_gather_index(
+    const torch::Tensor& token_num,
+    int64_t pad_num,
+    bool return_cusum_token_count);
+
+std::vector<torch::Tensor> moe_all2all_create(int64_t dispatch_token_byte,
+                                              int64_t combine_token_byte,
+                                              int64_t max_expert_num,
+                                              int64_t max_token_num,
+                                              int64_t rank,
+                                              int64_t nrank,
+                                              const torch::Device& device);
+
+void moe_all2all_init(int64_t handle,
+                      const torch::Tensor& all_exchange_info,
+                      const torch::Device& device);
+
+void moe_all2all_dispatch(int64_t handle,
+                          int64_t token_byte,
+                          int64_t token_num,
+                          const torch::Tensor& send_layout,
+                          const torch::Tensor& send_token_num,
+                          const torch::Tensor& recv_layout,
+                          const torch::Tensor& recv_token_num,
+                          const std::optional<torch::Tensor>& send_token,
+                          const std::optional<torch::Tensor>& recv_token);
+
+void moe_all2all_combine(int64_t handle,
+                         int64_t token_byte,
+                         int64_t token_num,
+                         const torch::Tensor& send_src_layout,
+                         const torch::Tensor& send_dst_layout,
+                         const std::optional<torch::Tensor>& send_token,
+                         const std::optional<torch::Tensor>& recv_token);
+
+void moe_all2all_destroy(int64_t handle, const torch::Device& device);
+
 std::tuple<torch::Tensor, torch::Tensor> scaled_quantize(
     const torch::Tensor& x,
     const torch::Tensor& smooth,
@@ -222,4 +262,10 @@ torch::Tensor apply_top_k_top_p(const torch::Tensor& logits,
 
 torch::Tensor random_sample(const torch::Tensor& probs);
 
+void gather_split(const torch::Tensor& input,
+                  const torch::Tensor& gather_index,
+                  const torch::Tensor& valid_token_num,
+                  const torch::Tensor& output_head,
+                  const torch::Tensor& output_tail);
+
 }  // namespace xllm::kernel::mlu
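
Read together, these declarations imply a create, init, dispatch, combine, destroy lifecycle around an opaque int64_t handle. The sketch below is pieced together from the signatures alone: the commit shows no caller, so the handle encoding in moe_all2all_create's outputs, the byte sizes, and the cross-rank exchange of all_exchange_info are all hypothetical.

// Hypothetical single-rank lifecycle; numeric sizes, the handle living in
// outputs[0], and the allgather step are assumptions, not from the commit.
void moe_all2all_lifecycle_sketch(const torch::Device& dev,
                                  int64_t rank,
                                  int64_t nrank) {
  using namespace xllm::kernel::mlu;
  auto outputs = moe_all2all_create(/*dispatch_token_byte=*/14336,
                                    /*combine_token_byte=*/14336,
                                    /*max_expert_num=*/256,
                                    /*max_token_num=*/4096,
                                    rank, nrank, dev);
  int64_t handle = outputs[0].item<int64_t>();  // assumed: handle in outputs[0]

  // Every rank must see every peer's exchange info before init; a real
  // caller would allgather outputs[1] across ranks first.
  torch::Tensor all_exchange_info = outputs[1];
  moe_all2all_init(handle, all_exchange_info, dev);

  // Steady state per MoE layer: generate layouts, dispatch tokens to the
  // expert-owning ranks, run the experts, combine results back:
  //   moe_all2all_gen_send_layout(...), moe_all2all_dispatch(...),
  //   <expert GEMMs>, moe_all2all_combine(...)

  moe_all2all_destroy(handle, dev);
}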
Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlu_ops_api.h"
+
+namespace xllm::kernel::mlu {
+
+torch::Tensor moe_all2all_gen_send_layout(const torch::Tensor& token_count,
+                                          int64_t nrank) {
+  return tmo::torch_api::moe_all2all_gen_send_layout(token_count, nrank);
+}
+
+std::vector<torch::Tensor> moe_all2all_gen_gather_index(
+    const torch::Tensor& token_num,
+    int64_t pad_num,
+    bool return_cusum_token_count) {
+  // get dimension information
+  int32_t rank_num = token_num.size(0);
+  int32_t expert_num = token_num.size(1);
+
+  // prepare tensor options (keep same device as input, enforce int32)
+  auto options =
+      torch::TensorOptions().dtype(torch::kInt).device(token_num.device());
+
+  // output tensors
+  torch::Tensor gather_by_expert_index =
+      torch::empty({rank_num * pad_num}, options);
+  torch::Tensor gather_by_rank_index =
+      torch::empty({rank_num * pad_num}, options);
+  torch::Tensor token_count = torch::empty({expert_num}, options);
+  torch::Tensor token_sum = torch::empty({1}, options);
+
+  // handle optional tensor allocation
+  torch::Tensor cusum_token_count;
+  if (return_cusum_token_count) {
+    cusum_token_count = torch::empty({expert_num + 1}, options);
+  }
+
+  tmo::torch_api::moe_all2all_gen_gather_index(gather_by_expert_index,
+                                               gather_by_rank_index,
+                                               token_count,
+                                               cusum_token_count,
+                                               token_sum,
+                                               token_num,
+                                               pad_num);
+
+  // pack and return results using std::vector
+  std::vector<torch::Tensor> results;
+  results.reserve(return_cusum_token_count ? 5 : 4);
+
+  results.push_back(gather_by_expert_index);
+  results.push_back(gather_by_rank_index);
+  results.push_back(token_count);
+  results.push_back(token_sum);
+
+  if (return_cusum_token_count) {
+    results.push_back(cusum_token_count);
+  }
+
+  return results;
+}
+
+std::vector<torch::Tensor> moe_all2all_create(int64_t dispatch_token_byte,
+                                              int64_t combine_token_byte,
+                                              int64_t max_expert_num,
+                                              int64_t max_token_num,
+                                              int64_t rank,
+                                              int64_t nrank,
+                                              const torch::Device& device) {
+  // Create placeholder tensor on the specified device
+  auto options = torch::TensorOptions().device(device);
+  torch::Tensor place_holder = torch::empty({0}, options);
+
+  // Call the underlying operator
+  // Since the return type is explicitly std::vector<torch::Tensor>, we capture
+  // it directly.
+  std::vector<torch::Tensor> outputs =
+      tmo::torch_api::moe_all2all_create(dispatch_token_byte,
+                                         combine_token_byte,
+                                         max_expert_num,
+                                         max_token_num,
+                                         rank,
+                                         nrank,
+                                         place_holder);
+  // Return all 6 tensors
+  // Construct a new vector from the iterator range
+  return std::vector<torch::Tensor>(outputs.begin(), outputs.end());
+}
+
+void moe_all2all_init(int64_t handle,
+                      const torch::Tensor& all_exchange_info,
+                      const torch::Device& device) {
+  auto options = torch::TensorOptions().device(device);
+  torch::Tensor place_holder = torch::empty({0}, options);
+  tmo::torch_api::moe_all2all_init(handle, all_exchange_info, place_holder);
+}
+
+void moe_all2all_dispatch(int64_t handle,
+                          int64_t token_byte,
+                          int64_t token_num,
+                          const torch::Tensor& send_layout,
+                          const torch::Tensor& send_token_num,
+                          const torch::Tensor& recv_layout,
+                          const torch::Tensor& recv_token_num,
+                          const std::optional<torch::Tensor>& send_token,
+                          const std::optional<torch::Tensor>& recv_token) {
+  tmo::torch_api::moe_all2all_dispatch(handle,
+                                       token_byte,
+                                       token_num,
+                                       send_layout,
+                                       send_token_num,
+                                       recv_layout,
+                                       recv_token_num,
+                                       send_token,
+                                       recv_token);
+}
+
+void moe_all2all_combine(int64_t handle,
+                         int64_t token_byte,
+                         int64_t token_num,
+                         const torch::Tensor& send_src_layout,
+                         const torch::Tensor& send_dst_layout,
+                         const std::optional<torch::Tensor>& send_token,
+                         const std::optional<torch::Tensor>& recv_token) {
+  tmo::torch_api::moe_all2all_combine(handle,
+                                      token_byte,
+                                      token_num,
+                                      send_src_layout,
+                                      send_dst_layout,
+                                      send_token,
+                                      recv_token);
+}
+
+void moe_all2all_destroy(int64_t handle, const torch::Device& device) {
+  auto options = torch::TensorOptions().device(device);
+  torch::Tensor place_holder = torch::empty({0}, options);
+  tmo::torch_api::moe_all2all_destroy(handle, place_holder);
+}
+
+}  // namespace xllm::kernel::mlu
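
Two details worth noting. First, create/init/destroy pass an empty place_holder tensor to the tmo op, presumably so the underlying operator can resolve the target device (and its queue) from a tensor argument even when it has no real data to touch. Second, moe_all2all_gen_gather_index returns four or five tensors depending on return_cusum_token_count, so callers must unpack positionally. A hedged unpacking sketch, with invented token counts:

// Hypothetical caller-side unpacking; the [rank_num, expert_num] layout of
// token_num mirrors the .size(0)/.size(1) reads in the wrapper above.
torch::Device dev(torch::kCPU);  // stand-in; real use targets the MLU device
auto token_num = torch::randint(
    0, 16, {/*rank_num=*/4, /*expert_num=*/8},
    torch::TensorOptions().dtype(torch::kInt).device(dev));
auto outs = xllm::kernel::mlu::moe_all2all_gen_gather_index(
    token_num, /*pad_num=*/64, /*return_cusum_token_count=*/true);
auto gather_by_expert_index = outs[0];  // [rank_num * pad_num]
auto gather_by_rank_index   = outs[1];  // [rank_num * pad_num]
auto token_count            = outs[2];  // [expert_num]
auto token_sum              = outs[3];  // [1]
auto cusum_token_count      = outs[4];  // [expert_num + 1], only if requested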

xllm/core/kernels/ops_api.cpp

Lines changed: 91 additions & 10 deletions
@@ -303,8 +303,6 @@ torch::Tensor group_gemm(GroupGemmParams& params) {
                          params.trans_a,
                          params.trans_b,
                          params.a_quant_bit);
-#elif defined(USE_CUDA)
-  LOG(FATAL) << "group_gemm for cuda not implemented";
 #else
   LOG(FATAL) << "group_gemm not implemented";
 #endif
@@ -323,8 +321,6 @@ std::tuple<torch::Tensor, torch::Tensor> moe_active_topk(
                               params.scoring_func,
                               params.route_scale,
                               params.e_score_correction_bias);
-#elif defined(USE_CUDA)
-  LOG(FATAL) << "moe_active_topk for cuda not implemented";
 #else
   LOG(FATAL) << "moe_active_topk not implemented";
 #endif
@@ -333,8 +329,6 @@ std::tuple<torch::Tensor, torch::Tensor> moe_active_topk(
 std::vector<torch::Tensor> moe_gen_idx(MoeGenIdxParams& params) {
 #if defined(USE_MLU)
   return mlu::moe_gen_idx(params.expert_id, params.expert_num);
-#elif defined(USE_CUDA)
-  LOG(FATAL) << "moe_gen_idx for cuda not implemented";
 #else
   LOG(FATAL) << "moe_gen_idx not implemented";
 #endif
@@ -347,8 +341,6 @@ torch::Tensor moe_expand_input(MoeExpandInputParams& params) {
                                params.cusum_token_count,
                                params.start_expert_id,
                                params.expert_size);
-#elif defined(USE_CUDA)
-  LOG(FATAL) << "moe_expand_input for cuda not implemented";
 #else
   LOG(FATAL) << "moe_expand_input not implemented";
 #endif
@@ -364,13 +356,90 @@ torch::Tensor moe_combine_result(MoeCombineResultParams& params) {
                                  params.start_expert_id,
                                  params.expert_size,
                                  params.bias);
-#elif defined(USE_CUDA)
-  LOG(FATAL) << "moe_combine_result for cuda not implemented";
 #else
   LOG(FATAL) << "moe_combine_result not implemented";
 #endif
 }
 
+torch::Tensor moe_all2all_gen_send_layout(
+    MoeAll2AllGenSendLayoutParams& params) {
+#if defined(USE_MLU)
+  return mlu::moe_all2all_gen_send_layout(params.token_count, params.nrank);
+#else
+  LOG(FATAL) << "moe_all2all_gen_send_layout not implemented";
+#endif
+}
+
+std::vector<torch::Tensor> moe_all2all_gen_gather_index(
+    MoeAll2AllGenGatherIndexParams& params) {
+#if defined(USE_MLU)
+  return mlu::moe_all2all_gen_gather_index(
+      params.token_num, params.pad_num, params.return_cusum_token_count);
+#else
+  LOG(FATAL) << "moe_all2all_gen_gather_index not implemented";
+#endif
+}
+
+std::vector<torch::Tensor> moe_all2all_create(MoeAll2AllCreateParams& params) {
+#if defined(USE_MLU)
+  return mlu::moe_all2all_create(params.dispatch_token_byte,
+                                 params.combine_token_byte,
+                                 params.max_expert_num,
+                                 params.max_token_num,
+                                 params.rank,
+                                 params.nrank,
+                                 params.device);
+#else
+  LOG(FATAL) << "moe_all2all_create not implemented";
+#endif
+}
+
+void moe_all2all_init(MoeAll2AllInitParams& params) {
+#if defined(USE_MLU)
+  mlu::moe_all2all_init(params.handle, params.all_exchange_info, params.device);
+#else
+  LOG(FATAL) << "moe_all2all_init not implemented";
+#endif
+}
+
+void moe_all2all_dispatch(MoeAll2AllDispatchParams& params) {
+#if defined(USE_MLU)
+  mlu::moe_all2all_dispatch(params.handle,
+                            params.token_byte,
+                            params.token_num,
+                            params.send_layout,
+                            params.send_token_num,
+                            params.recv_layout,
+                            params.recv_token_num,
+                            params.send_token,
+                            params.recv_token);
+#else
+  LOG(FATAL) << "moe_all2all_dispatch not implemented";
+#endif
+}
+
+void moe_all2all_combine(MoeAll2AllCombineParams& params) {
+#if defined(USE_MLU)
+  mlu::moe_all2all_combine(params.handle,
+                           params.token_byte,
+                           params.token_num,
+                           params.send_src_layout,
+                           params.send_dst_layout,
+                           params.send_token,
+                           params.recv_token);
+#else
+  LOG(FATAL) << "moe_all2all_combine not implemented";
+#endif
+}
+
+void moe_all2all_destroy(MoeAll2AllDestroyParams& params) {
+#if defined(USE_MLU)
+  mlu::moe_all2all_destroy(params.handle, params.device);
+#else
+  LOG(FATAL) << "moe_all2all_destroy not implemented";
+#endif
+}
+
 std::tuple<torch::Tensor, torch::Tensor> scaled_quantize(
     ScaledQuantizeParams& params) {
 #if defined(USE_MLU)
@@ -455,4 +524,16 @@ void masked_indexer_select_paged_kv(MaskedIndexerSelectPagedKVParams& params) {
 #endif
 }
 
+void gather_split(GatherSplitParams& params) {
+#if defined(USE_MLU)
+  mlu::gather_split(params.input,
+                    params.gather_index,
+                    params.valid_token_num,
+                    params.output_head,
+                    params.output_tail);
+#else
+  LOG(FATAL) << "gather_split not implemented";
+#endif
+}
+
 }  // namespace xllm::kernel
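
At this layer callers pass a params struct instead of positional arguments, and the backend is selected at compile time. The struct definitions are not part of this diff, so the field-by-field construction below is inferred from the params.* accesses above and should be read as an assumption:

// Hypothetical call through the backend-agnostic entry point; the fields of
// GatherSplitParams are inferred from the member accesses, not shown here.
xllm::kernel::GatherSplitParams p;
p.input = input;                      // tensors prepared as in the earlier
p.gather_index = gather_index;        // gather_split sketch
p.valid_token_num = valid_token_num;
p.output_head = output_head;
p.output_tail = output_tail;
xllm::kernel::gather_split(p);  // routes to mlu::gather_split under USE_MLU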
