[Pass Bug] insert-sync: missing MTE3->MTE2 hazard between pto.tstore and pto.comm.tput (TPUT reads src on MTE2) #706

@YunjiQin

Pass / pipeline name

--enable-insert-sync (auto-sync pass) — missing MTE3 → MTE2 hazard between a
pto.tstore writing to a window-bound GM tensor and a subsequent pto.comm.tput
that consumes that same tensor as its src operand.

Summary

For a kernel that does pto.tstore to a local GM tensor followed by
pto.comm.tput reading that same tensor as src, --enable-insert-sync
emits set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0) right after the pto.tstore,
but places the matching wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0) after
the pto.comm.tput call (before a later, unrelated MTE2 op).

TputTransferOnce (pto-isa pto/comm/a2a3/TPut.hpp:46) opens with
TLOAD(stage, src), a PIPE_MTE2 read of the local src GM. Without a wait before
the TPUT, that internal TLOAD can race the still-uncommitted pto.tstore. In
addition, TPUT's own internal set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0) /
wait_flag(...) pair uses the same event id as the outer set that is still
pending, so the two collide.

Result: silent wrong output — pto.comm.tput pushes stale data to the peer.

End-to-end symptom (PyPTO tests/st/distributed/test_l3_put.py::TestL3Put::test_ring_shuffle,
2x a3 ranks, ring-overwrite):

outputs[0] = [100, 101, ..., 163]   # rank 1 -> rank 0 push OK (won the race)
outputs[1] = [  0,   0, ...,   0]   # rank 0 -> rank 1 push lost data
max diff = 63.0

Version

ptoas 0.41

Command

ptoas ring_step.pto --enable-insert-sync --pto-level=level3 --pto-arch=a3 -o ring_step.cpp

(invoked by PyPTO _get_ptoas_flags() in python/pypto/backend/pto_backend.py:818)

Before IR (input .pto, minimal -- strips Phase-3 notify/wait + Phase-4 readback)

module attributes {pto.target_arch = "a2a3"} {
  func.func @ring_step(%arg0: !pto.ptr<f32>, %arg1: !pto.ptr<f32>, %arg2: !pto.ptr<f32>,
                       %arg3: !pto.ptr<i64>, %arg4: i32)
      attributes {pto.kernel_kind = #pto.kernel_kind<vector>} {
    %c0_i64 = arith.constant 0  : i64
    %c1     = arith.constant 1  : index
    %c64    = arith.constant 64 : index
    %c0_idx = arith.constant 0  : index

    %inp_view = pto.make_tensor_view %arg0, shape = [%c1, %c64], strides = [%c64, %c1]
                {layout = #pto.layout<nd>} : !pto.tensor_view<?x?xf32>
    %src_view = pto.make_tensor_view %arg1, shape = [%c1, %c64], strides = [%c64, %c1]
                {layout = #pto.layout<nd>} : !pto.tensor_view<?x?xf32>
    %dst_view = pto.make_tensor_view %arg2, shape = [%c1, %c64], strides = [%c64, %c1]
                {layout = #pto.layout<nd>} : !pto.tensor_view<?x?xf32>

    %tile = pto.alloc_tile addr = %c0_i64 valid_row = %c1 valid_col = %c64
          : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=64, v_row=?, v_col=?,
                           blayout=row_major, slayout=none_box, fractal=512, pad=0>

    // Phase 1: TLOAD inp -> tile  (MTE2)
    %inp_p = pto.partition_view %inp_view, offsets = [%c0_idx, %c0_idx], sizes = [%c1, %c64]
           : !pto.tensor_view<?x?xf32> -> !pto.partition_tensor_view<1x64xf32>
    pto.tload ins(%inp_p) outs(%tile)

    // Phase 1: TSTORE tile -> local src GM  (MTE3)
    %src_p = pto.partition_view %src_view, offsets = [%c0_idx, %c0_idx], sizes = [%c1, %c64]
           : !pto.tensor_view<?x?xf32> -> !pto.partition_tensor_view<1x64xf32>
    pto.tstore ins(%tile) outs(%src_p)

    // Phase 2: TPUT -- reads local src GM (MTE2 inside TPUT) -> writes peer's dst GM
    %stage = pto.alloc_tile addr = %c0_i64 valid_row = %c1 valid_col = %c64
           : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=64, v_row=?, v_col=?,
                            blayout=row_major, slayout=none_box, fractal=512, pad=0>
    %peer_idx     = arith.index_cast %arg4 : i32 to index
    %off          = func.call @CommRemoteOffset_f32(%arg3, %peer_idx)
                    : (!pto.ptr<i64>, index) -> index
    %dst_peer_ptr = pto.addptr %arg2, %off : !pto.ptr<f32> -> !pto.ptr<f32>
    %stride       = arith.muli %c1, %c64 : index
    %dst_peer_view = pto.make_tensor_view %dst_peer_ptr, shape = [%c1, %c64],
                     strides = [%stride, %c1] {layout = #pto.layout<nd>}
                   : !pto.tensor_view<?x?xf32>

    %dst_peer_p  = pto.partition_view %dst_peer_view, offsets = [%c0_idx, %c0_idx],
                   sizes = [%c1, %c64]
                 : !pto.tensor_view<?x?xf32> -> !pto.partition_tensor_view<1x64xf32>
    %src_local_p = pto.partition_view %src_view, offsets = [%c0_idx, %c0_idx],
                   sizes = [%c1, %c64]
                 : !pto.tensor_view<?x?xf32> -> !pto.partition_tensor_view<1x64xf32>
    pto.comm.tput(%dst_peer_p, %src_local_p, buf(%stage)
                  : !pto.partition_tensor_view<1x64xf32>,
                    !pto.partition_tensor_view<1x64xf32>,
                    !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=64, v_row=?, v_col=?,
                                  blayout=row_major, slayout=none_box, fractal=512, pad=0>)
        {atomicType = #pto<atomic_type atomic_none>}
    return
  }
  func.func private @CommRemoteOffset_f32(%ctx: !pto.ptr<i64>, %peer: index) -> index { /* ... */ }
}

The two partition_view ops into %src_view (%src_p for the tstore, %src_local_p
for the tput's src operand) target the same underlying tensor view but are
distinct SSA values -- a shape similar to #536 / #533.
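
Since the two views are distinct SSA values, a dependency check keyed on SSA identity alone would miss the overlap. Below is a minimal sketch of comparing views by root argument and element range instead. `ViewRef` and `mayAlias` are hypothetical names for illustration, not ptoas APIs, and the flat 1-D element range is a simplifying assumption.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical, simplified description of a partition view: the function
// argument it is ultimately derived from, plus a flat element range.
struct ViewRef {
    std::string ssaName;  // e.g. "%src_p"
    std::string rootArg;  // e.g. "%arg1"
    std::int64_t begin;   // first element covered
    std::int64_t size;    // number of elements covered
};

// Two views may alias when they share a root and their element ranges
// overlap, whether or not they are the same SSA value.
inline bool mayAlias(const ViewRef& a, const ViewRef& b) {
    if (a.rootArg != b.rootArg) return false;
    return a.begin < b.begin + b.size && b.begin < a.begin + a.size;
}
```

In this model %src_p and %src_local_p both resolve to %arg1 with the range [0, 64), so they would be reported as aliasing even though their SSA names differ.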

Expected behavior

A wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_IDx) should land before
pto.comm.tput, pairing with the set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_IDx)
emitted after pto.tstore. TPUT's internal sync should not collide with the
outer event id -- either reserve a distinct event id for TPUT's internal pair,
or close the outer pair before entering TPUT.

Manual verification: adding a single pipe_barrier(PIPE_ALL) at the top of
TputTransferOnce (pto-isa pto/comm/a2a3/TPut.hpp:46) makes the test pass.
That is a workaround at the library level; the proper fix is in the
insert-sync pass.

Actual IR / generated kernel

Relevant slice of the generated ring_step.cpp (full file lines 92-131):

TLOAD(v16, v20);                              // Phase 1 TLOAD inp  (MTE2)
set_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
// ... build src GlobalTensor v23 wrapping %arg1 (local src GM) ...
wait_flag(PIPE_MTE2, PIPE_MTE3, EVENT_ID0);
TSTORE(v23, v16);                             // Phase 1 TSTORE -> local src GM  (MTE3)
set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);    // (!) no matching wait before TPUT

// ... build dst-peer GlobalTensor v32, allocate staging v25 ...

pto::comm::TPUT(v32, v23, v25);               // (!) TPUT's internal TLOAD on v23
                                              //     races the still-pending TSTORE;
                                              //     also issues its own
                                              //     set_flag/wait_flag pair on
                                              //     EVENT_ID0, colliding with the outer set.

// ... unrelated Phase 3 / Phase 4 dst-side build code ...

wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0);   // (!) consumed AFTER TPUT --
                                              //     gates Phase 4 TLOAD(dst), not the TPUT
TLOAD(v41, v45);                              // Phase 4 TLOAD local dst (MTE2)

Two problems:

  1. The MTE3->MTE2 hazard from pto.tstore -> pto.comm.tput's internal TLOAD
    of the same GM tensor is not honored. TPUT's first internal op is
    TLOAD(stage, src) on PIPE_MTE2 -- the pass doesn't model that read.
  2. EVENT_ID0 is reused by TPUT's internal sync, so the outer
    set_flag(MTE3, MTE2, EVENT_ID0) and the later
    wait_flag(MTE3, MTE2, EVENT_ID0) no longer form a valid pair.
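
The second problem can be illustrated with a toy model of the event flags. This sketch assumes each (source pipe, destination pipe, event id) triple behaves as a single pending bit. Real hardware semantics may differ, and `FlagModel` is a hypothetical helper, not part of pto-isa.

```cpp
#include <cassert>
#include <map>
#include <tuple>

enum class Pipe { MTE2, MTE3 };

using FlagKey = std::tuple<Pipe, Pipe, int>;

// Toy model: one pending bit per (src pipe, dst pipe, event id).
class FlagModel {
public:
    // Returns false if the flag was already pending, i.e. a second set
    // arrived before any wait consumed the first one.
    bool set(Pipe src, Pipe dst, int id) {
        bool& pending = pending_[FlagKey{src, dst, id}];
        if (pending) return false;
        pending = true;
        return true;
    }

    // Returns false if nothing was pending, i.e. this wait has no set
    // left to pair with.
    bool wait(Pipe src, Pipe dst, int id) {
        bool& pending = pending_[FlagKey{src, dst, id}];
        if (!pending) return false;
        pending = false;
        return true;
    }

private:
    std::map<FlagKey, bool> pending_;
};
```

Replaying the generated sequence under this model (outer set, TPUT's internal set, TPUT's internal wait, outer wait), the internal set hits an already-pending flag and the final outer wait finds nothing left to consume.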

Root cause hypothesis

--enable-insert-sync doesn't classify pto.comm.tput as reading its src
operand on PIPE_MTE2 and writing its dst operand on PIPE_MTE3. The pass
treats TPUT as opaque, so the producer->consumer edge from any prior PIPE_MTE3
writer (typically a pto.tstore) to TPUT's PIPE_MTE2 read of the same GM
region is never inserted.

Additionally, event-id allocation does not exclude the ids used by TPUT's
internal sync (EVENT_ID0 in TputTransferOnce), so when the pass does emit
outer pairs around a TPUT call, they collide.

Suggested fix

In the operand effect table consumed by --enable-insert-sync, classify
pto.comm.tput:

  • src operand: PIPE_MTE2 read of GM (mirrors pto.tload's source effect);
  • dst operand: PIPE_MTE3 write of GM (cross-NPU GM write);
  • buf operand: tile_buf scratch -- same UB lifetime as a TPUT_IMPL ping/pong tile.

With those edges, the pass would emit wait_flag(MTE3, MTE2, EVENT_IDx)
before pto.comm.tput whenever a prior pto.tstore (or any MTE3 GM writer)
targets the same underlying GM region as the tput's src operand, even when
the two partition views are distinct SSA values.
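
A rough sketch of how such per-operand effects could drive hazard detection. All type and function names here (`Effect`, `Op`, `needsMte3ToMte2Wait`) are hypothetical and do not reflect the actual ptoas data structures. GM regions are reduced to string labels, and syncs that are already in place are ignored.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

enum class Pipe { MTE2, MTE3 };
enum class Access { Read, Write };

// One memory effect of an op: which GM region, on which pipe, read or write.
struct Effect {
    std::string region;
    Pipe pipe;
    Access access;
};

struct Op {
    std::string name;
    std::vector<Effect> effects;  // empty means the op is treated as opaque
};

// Returns the indices of ops that perform an MTE2 read of a region that an
// earlier op wrote on MTE3, i.e. ops that need a preceding
// wait_flag(MTE3, MTE2, ...).
std::vector<std::size_t> needsMte3ToMte2Wait(const std::vector<Op>& ops) {
    std::vector<std::size_t> result;
    for (std::size_t i = 0; i < ops.size(); ++i) {
        bool hazard = false;
        for (const Effect& rd : ops[i].effects) {
            if (rd.pipe != Pipe::MTE2 || rd.access != Access::Read) continue;
            for (std::size_t j = 0; j < i && !hazard; ++j) {
                for (const Effect& wr : ops[j].effects) {
                    if (wr.pipe == Pipe::MTE3 && wr.access == Access::Write &&
                        wr.region == rd.region) {
                        hazard = true;
                        break;
                    }
                }
            }
        }
        if (hazard) result.push_back(i);
    }
    return result;
}
```

With pto.comm.tput given an MTE2 read effect on the src region, the tput is flagged as needing a wait. With an empty effect list (the opaque treatment described in the root cause hypothesis), nothing is flagged.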

For the event-id collision, either reserve a distinct event id for TPUT's
internal sync, or surface its event id through the TPUT intrinsic so the pass
can route around it.
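
One possible shape for the "reserve" option, as a sketch only: the allocator skips ids that a library routine is known to use internally. `pickEventId` and its parameters are hypothetical, and the number of available event ids is left as an input because it is not stated here.

```cpp
#include <cassert>
#include <set>

// Returns the lowest event id in [0, numIds) that is neither currently in
// use by an outer set/wait pair nor reserved for library-internal sync.
// Returns -1 when no id is available.
int pickEventId(int numIds, const std::set<int>& inUse,
                const std::set<int>& reserved) {
    for (int id = 0; id < numIds; ++id) {
        if (inUse.count(id) == 0 && reserved.count(id) == 0) return id;
    }
    return -1;
}
```

With id 0 marked as reserved for TputTransferOnce, an outer pair that spans a TPUT call would be assigned a different id and could no longer collide with the internal pair.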

The mirror op pto.comm.tget likely has the dual issue on its output side: its
final internal TSTORE writes local dst GM, and any later reader of that GM
region must wait for it. This has not yet been validated end-to-end; it is
flagged here pending confirmation.

Target arch

a3

PTOAS build level

level3
