Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions examples/workers/l3/ep_dispatch_combine/kernels/aiv/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
* recovers column-0 values
* because columns [1, W_PAD)
* are zero by design);
* recv_idx_out is [L, R] (scalar copy of column 0)
* recv_idx_out is [L, R] (TROWSUM over [L,R,IDX_PAD]
* wide window, INT32)
*
* Design notes:
* - All cross-rank GM writes go through tile primitives (TPUT). No AIV
Expand All @@ -44,9 +45,8 @@
* - Weight uses TROWSUM along the W_PAD axis to compact the wide window
* [L, R, W_PAD] → [L, R] FP32: sum-of-row recovers slot [0] because the
* other lanes are zero. One TLOAD + TROWSUM + TSTORE per expert.
* - Idx uses scalar GM copy of column 0 to compact [L, R, IDX_PAD] →
* [L, R] INT32. INT32 TROWSUM exists in pto-isa but hangs on a2a3 in
* this configuration; the L*R = 128 scalar stores are negligible.
* - Idx uses the same TROWSUM compaction along the IDX_PAD axis to compact
* [L, R, IDX_PAD] → [L, R] INT32. One TLOAD + TROWSUM + TSTORE per expert.
*/

#ifndef __gm__
Expand Down Expand Up @@ -484,9 +484,9 @@ extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ in
TASSIGN(w_wide_tile, 0x10000);
TASSIGN(w_sum_tile, 0x20000);
TASSIGN(w_tmp_tile, 0x21000);
TASSIGN(idx_wide_tile, 0x30000);
TASSIGN(idx_sum_tile, 0x40000);
TASSIGN(idx_tmp_tile, 0x41000);
TASSIGN(idx_wide_tile, 0x10000);
TASSIGN(idx_sum_tile, 0x20000);
TASSIGN(idx_tmp_tile, 0x21000);

// Stage out x: per-row 1xD copies.
for (int e = 0; e < L; ++e) {
Expand Down Expand Up @@ -527,18 +527,30 @@ extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ in
wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1);
}

// Stage out idx: scalar copy of column 0 from the wide window.
//
// ⚠ The same TROWSUM compaction used above for the FP32 weight channel
// does NOT work reliably for INT32 on a2a3: pto-isa declares INT32
// TROWSUM support, but with the same [R, IDX_PAD] / Layout::DN setup
// the kernel hangs on hardware. Until that path is stabilized, fall
// back to a scalar copy here. Volume is small (L*R = 128 INT32 stores)
// so the perf cost is negligible.
// Drain the weight loop's last TSTORE before reusing the same UB slots
// for idx_*. Without this fence, the idx TLOAD could overwrite UB while
// the trailing w TSTORE is still in flight on MTE3.
pipe_barrier(PIPE_ALL);

// Stage out idx: same TROWSUM compaction as the weight channel, on the
// INT32 [R, IDX_PAD] wide window. sum-along-PAD recovers slot [0] because
// columns [1, IDX_PAD) are zero by design.
for (int e = 0; e < L; ++e) {
for (int slot = 0; slot < R; ++slot) {
recv_idx_out[e * R + slot] = recv_idx_local[(e * R + slot) * IDX_PAD];
}
__gm__ int32_t *idx_win = recv_idx_local + e * R * IDX_PAD;
__gm__ int32_t *idx_out = recv_idx_out + e * R;
IWideG idx_win_g(idx_win);
ISumG idx_out_g(idx_out);
TLOAD(idx_wide_tile, idx_win_g);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1);
pipe_barrier(PIPE_V);
TROWSUM(idx_sum_tile, idx_wide_tile, idx_tmp_tile);
pipe_barrier(PIPE_V);
set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1);
TSTORE(idx_out_g, idx_sum_tile);
Comment thread
zhangqi-chen marked this conversation as resolved.
set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1);
wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID1);
}
pipe_barrier(PIPE_ALL);
}
2 changes: 1 addition & 1 deletion examples/workers/l3/ep_dispatch_combine/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
[weight, 0, 0, …, 0]; receiver writes recv_w[loc_e][slot, :W_PAD]
and the kernel TROWSUM-compacts to a [L, R] FP32 host output.
- Idx uses the same minimum-tile rationale: 1xIDX_PAD=8 INT32 per
route, actual r=t*TOPK+k at slot [0]; compacted via scalar copy to
route, actual r=t*TOPK+k at slot [0]; TROWSUM-compacted to
[L, R] INT32 host output. Combine reads it to address
routed_y_buf[t, k, :] without a host-built origin_map.
- ``recv_count_out`` is [L, 1] INT32 emitted by dispatch's prefix_sum
Expand Down
Loading