From 9c34bc572125aae4af4a7450602888acd988039d Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Wed, 27 May 2026 11:29:31 +0800 Subject: [PATCH 1/3] Fix p2p comm insert sync modeling --- .../Transforms/InsertSync/PTOIRTranslator.h | 7 + .../Transforms/InsertSync/PTOIRTranslator.cpp | 66 ++++++++-- .../lit/pto/issue706_comm_p2p_insert_sync.pto | 122 ++++++++++++++++++ 3 files changed, 185 insertions(+), 10 deletions(-) create mode 100644 test/lit/pto/issue706_comm_p2p_insert_sync.pto diff --git a/include/PTO/Transforms/InsertSync/PTOIRTranslator.h b/include/PTO/Transforms/InsertSync/PTOIRTranslator.h index 9f329e517..b354e2034 100644 --- a/include/PTO/Transforms/InsertSync/PTOIRTranslator.h +++ b/include/PTO/Transforms/InsertSync/PTOIRTranslator.h @@ -88,6 +88,13 @@ class PTOIRTranslator { // --- 核心:处理计算/搬运指令 (生成 Compound 节点) --- void UpdatePTOOpInfo(Operation *op); + void UpdateP2PCommOpInfo(pto::TPutOp op); + void UpdateP2PCommOpInfo(pto::TGetOp op); + void AddPTOOpInfo(Operation *op, PipelineType pipe, ValueRange defs, + ValueRange uses); + void AddCompoundOpInfo(Operation *op, PipelineType pipe, + SmallVector defVec, + SmallVector useVec); // --- 辅助函数 --- diff --git a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp index 0f3dba41d..7a7e9d667 100644 --- a/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp +++ b/lib/PTO/Transforms/InsertSync/PTOIRTranslator.cpp @@ -360,6 +360,10 @@ void PTOIRTranslator::RecursionIR(Region *region) { return WalkResult::skip(); } else if (auto yieldOp = dyn_cast(op)) { UpdateYieldOpInfo(yieldOp); + } else if (auto tputOp = dyn_cast(op)) { + UpdateP2PCommOpInfo(tputOp); + } else if (auto tgetOp = dyn_cast(op)) { + UpdateP2PCommOpInfo(tgetOp); } else if (isa(op)) { // --- Case D: 带有 OpPipeInterface 的计算/搬运指令 --- UpdatePTOOpInfo(op); @@ -555,17 +559,59 @@ void PTOIRTranslator::UpdatePTOOpInfo(Operation *op) { << " has Pipe but no MemoryEffects interface.\n"); } - // 3. 构建 Compound Node + AddCompoundOpInfo(op, pipe, std::move(defVec), std::move(useVec)); +} + +// ============================================================================ +// 6. Model compound p2p communication ops +// ============================================================================ +// TPUT/TGET hide a TLOAD-to-staging and TSTORE-from-staging sequence inside the +// PTO-ISA helper. Model those pipe effects at the call boundary for auto-sync. +void PTOIRTranslator::UpdateP2PCommOpInfo(pto::TPutOp op) { + SmallVector scratch{op.getPing()}; + if (Value pong = op.getPong()) + scratch.push_back(pong); + + AddPTOOpInfo(op.getOperation(), PipelineType::PIPE_MTE2, scratch, + ValueRange{op.getSrc()}); + AddPTOOpInfo(op.getOperation(), PipelineType::PIPE_MTE3, + ValueRange{op.getDst()}, scratch); +} + +void PTOIRTranslator::UpdateP2PCommOpInfo(pto::TGetOp op) { + SmallVector scratch{op.getPing()}; + if (Value pong = op.getPong()) + scratch.push_back(pong); + + AddPTOOpInfo(op.getOperation(), PipelineType::PIPE_MTE2, scratch, + ValueRange{op.getSrc()}); + AddPTOOpInfo(op.getOperation(), PipelineType::PIPE_MTE3, + ValueRange{op.getDst()}, scratch); +} + +void PTOIRTranslator::AddPTOOpInfo(Operation *op, PipelineType pipe, + ValueRange defs, ValueRange uses) { + if (pipe == pto::PipelineType::PIPE_UNASSIGNED) + return; + + SmallVector defVec; + SmallVector useVec; + UpdateDefUseVec(defs, defVec); + UpdateDefUseVec(uses, useVec); + AddCompoundOpInfo(op, pipe, std::move(defVec), std::move(useVec)); +} + +void PTOIRTranslator::AddCompoundOpInfo( + Operation *op, PipelineType pipe, SmallVector defVec, + SmallVector useVec) { auto compoundElement = std::make_unique( - index, defVec, useVec, pipe, op->getName()); + index, std::move(defVec), std::move(useVec), pipe, op->getName()); compoundElement->elementOp = op; - // 4. 设置 Core Type (用于区分 Cube/Vector 资源) - // Matmul (M) 和 L1->L0 搬运 (MTE1) 通常涉及 Cube 资源 - if (pipe == pto::PipelineType::PIPE_M || pipe == pto::PipelineType::PIPE_MTE1) { + if (pipe == pto::PipelineType::PIPE_M || + pipe == pto::PipelineType::PIPE_MTE1) { compoundElement->compoundCoreType = pto::TCoreType::CUBE; } else { - // MTE2, MTE3, Vector 归类为 Vector Core (或者对应 MTE 资源) compoundElement->compoundCoreType = pto::TCoreType::VECTOR; } @@ -574,7 +620,7 @@ void PTOIRTranslator::UpdatePTOOpInfo(Operation *op) { } // ============================================================================ -// 6. [P0 修改] 获取 Op 的 Pipeline 类型 +// 7. [P0 修改] 获取 Op 的 Pipeline 类型 // ============================================================================ pto::PipelineType PTOIRTranslator::getOpPipeline(Operation *op) { // 1. 优先尝试通过接口获取 @@ -589,7 +635,7 @@ pto::PipelineType PTOIRTranslator::getOpPipeline(Operation *op) { } // ============================================================================ -// 7. 控制流处理 (SCF Support) +// 8. 控制流处理 (SCF Support) // ============================================================================ void PTOIRTranslator::UpdateForOpInfo(scf::ForOp forOp) { @@ -719,7 +765,7 @@ void PTOIRTranslator::UpdateYieldOpInfo(scf::YieldOp yieldOp) { } // ============================================================================ -// 8. 辅助函数 +// 9. 辅助函数 // ============================================================================ void PTOIRTranslator::UpdateAliasBufferInfo(Value result, Value source) { if (!result || !source) return; @@ -940,7 +986,7 @@ void PTOIRTranslator::UpdateDefUseVec(ValueRange values, SmallVector, %local_src: !pto.ptr, + %remote_dst: !pto.ptr) + attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0_i64 = arith.constant 0 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %input_view = pto.make_tensor_view %input, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + %src_view = pto.make_tensor_view %local_src, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + %dst_view = pto.make_tensor_view %remote_dst, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + + %input_part = pto.partition_view %input_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %src_store_part = pto.partition_view %src_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %src_tput_part = pto.partition_view %src_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + + %tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %stage = pto.alloc_tile addr = %c4096_i64 + : !pto.tile_buf + + pto.tload ins(%input_part : !pto.partition_tensor_view<64xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%src_store_part : !pto.partition_tensor_view<64xf32>) + pto.comm.tput(%dst_part, %src_tput_part, buf(%stage) + : !pto.partition_tensor_view<64xf32>, + !pto.partition_tensor_view<64xf32>, + !pto.tile_buf) + {atomicType = #pto} + return + } + + func.func @issue706_tget_orders_later_tload( + %local_dst: !pto.ptr, %remote_src: !pto.ptr, + %sink: !pto.ptr) + attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0_i64 = arith.constant 0 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %dst_view = pto.make_tensor_view %local_dst, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + %src_view = pto.make_tensor_view %remote_src, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + %sink_view = pto.make_tensor_view %sink, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + + %dst_tget_part = pto.partition_view %dst_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %dst_load_part = pto.partition_view %dst_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %sink_part = pto.partition_view %sink_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + + %stage = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %tile = pto.alloc_tile addr = %c4096_i64 + : !pto.tile_buf + + pto.comm.tget(%dst_tget_part, %src_part, buf(%stage) + : !pto.partition_tensor_view<64xf32>, + !pto.partition_tensor_view<64xf32>, + !pto.tile_buf) + pto.tload ins(%dst_load_part : !pto.partition_tensor_view<64xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%sink_part : !pto.partition_tensor_view<64xf32>) + return + } +} + +// CHECK-LABEL: AICORE void issue706_tput_waits_for_prior_tstore( +// CHECK: TSTORE( +// CHECK-NEXT: set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID[[TPUT_ID:[0-9]+]]); +// CHECK: wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID[[TPUT_ID]]); +// CHECK-NEXT: pto::comm::TPUT( + +// CHECK-LABEL: AICORE void issue706_tget_orders_later_tload( +// CHECK: pto::comm::TGET( +// CHECK-NEXT: set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID[[TGET_ID:[0-9]+]]); +// CHECK-NEXT: wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID[[TGET_ID]]); +// CHECK-NEXT: TLOAD( From da369461884716845174c23c9c26f3c3b3b63195 Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Wed, 27 May 2026 15:00:30 +0800 Subject: [PATCH 2/3] Handle p2p comm in graph sync solver --- .../GraphSyncSolver/SyncSolverIRTranslator.h | 3 + .../SyncSolverIRTranslator.cpp | 26 +++- .../pto/issue706_comm_p2p_insert_sync_gss.pto | 122 ++++++++++++++++++ 3 files changed, 150 insertions(+), 1 deletion(-) create mode 100644 test/lit/pto/issue706_comm_p2p_insert_sync_gss.pto diff --git a/include/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.h b/include/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.h index b17175c1e..3b388ecf3 100644 --- a/include/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.h +++ b/include/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.h @@ -89,6 +89,9 @@ class IRTranslator { std::unique_ptr getPipeInterfaceOp(pto::OpPipeInterface op, OperationBase *parentOp); + template + std::unique_ptr getP2PCommOp(OP op, OperationBase *parentOp); + std::unique_ptr getTensorExtractOp(tensor::ExtractOp extractOp, OperationBase *parentOp); diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.cpp index a4cc5dd09..ef211a2a2 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.cpp @@ -192,6 +192,24 @@ IRTranslator::getPipeInterfaceOp(pto::OpPipeInterface op, pipeWrite, reads, writes); } +template +std::unique_ptr +IRTranslator::getP2PCommOp(OP commOp, OperationBase *parentOp) { + llvm::SmallVector readVals{commOp.getSrc(), commOp.getPing()}; + llvm::SmallVector writeVals{commOp.getDst(), commOp.getPing()}; + if (Value pong = commOp.getPong()) { + readVals.push_back(pong); + writeVals.push_back(pong); + } + auto reads = getMemoryOps(readVals); + auto writes = getMemoryOps(writeVals); + + // Synchronous TPUT/TGET hide MTE2 staging and MTE3 commit inside one call. + return std::make_unique( + commOp.getOperation(), parentOp, TCoreType::CUBE_OR_VECTOR, + pto::PIPE::PIPE_MTE2, pto::PIPE::PIPE_MTE3, reads, writes); +} + std::unique_ptr IRTranslator::getTensorExtractOp(tensor::ExtractOp extractOp, OperationBase *parentOp) { @@ -309,7 +327,13 @@ std::unique_ptr IRTranslator::funcIrBuilder(Region ®ion, continue; } - if (auto pipeOp = dyn_cast(op)) { + if (auto tputOp = dyn_cast(op)) { + if (auto rw = getP2PCommOp(tputOp, parScope)) + parScope->body.push_back(std::move(rw)); + } else if (auto tgetOp = dyn_cast(op)) { + if (auto rw = getP2PCommOp(tgetOp, parScope)) + parScope->body.push_back(std::move(rw)); + } else if (auto pipeOp = dyn_cast(op)) { if (auto rw = getPipeInterfaceOp(pipeOp, parScope)) parScope->body.push_back(std::move(rw)); } else if (auto storeOp = dyn_cast(op)) { diff --git a/test/lit/pto/issue706_comm_p2p_insert_sync_gss.pto b/test/lit/pto/issue706_comm_p2p_insert_sync_gss.pto new file mode 100644 index 000000000..3a9618bbf --- /dev/null +++ b/test/lit/pto/issue706_comm_p2p_insert_sync_gss.pto @@ -0,0 +1,122 @@ +// RUN: ptoas --pto-arch=a3 --pto-level=level3 --enable-graph-sync-solver --graph-sync-solver-event-id-max=64 %s | FileCheck %s + +module attributes {pto.target_arch = "a2a3"} { + func.func @issue706_tput_waits_for_prior_tstore( + %input: !pto.ptr, %local_src: !pto.ptr, + %remote_dst: !pto.ptr) + attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0_i64 = arith.constant 0 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %input_view = pto.make_tensor_view %input, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + %src_view = pto.make_tensor_view %local_src, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + %dst_view = pto.make_tensor_view %remote_dst, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + + %input_part = pto.partition_view %input_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %src_store_part = pto.partition_view %src_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %src_tput_part = pto.partition_view %src_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + + %tile = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %stage = pto.alloc_tile addr = %c4096_i64 + : !pto.tile_buf + + pto.tload ins(%input_part : !pto.partition_tensor_view<64xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%src_store_part : !pto.partition_tensor_view<64xf32>) + pto.comm.tput(%dst_part, %src_tput_part, buf(%stage) + : !pto.partition_tensor_view<64xf32>, + !pto.partition_tensor_view<64xf32>, + !pto.tile_buf) + {atomicType = #pto} + return + } + + func.func @issue706_tget_orders_later_tload( + %local_dst: !pto.ptr, %remote_src: !pto.ptr, + %sink: !pto.ptr) + attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0_i64 = arith.constant 0 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %dst_view = pto.make_tensor_view %local_dst, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + %src_view = pto.make_tensor_view %remote_src, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + %sink_view = pto.make_tensor_view %sink, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + + %dst_tget_part = pto.partition_view %dst_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %dst_load_part = pto.partition_view %dst_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %src_part = pto.partition_view %src_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %sink_part = pto.partition_view %sink_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + + %stage = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %tile = pto.alloc_tile addr = %c4096_i64 + : !pto.tile_buf + + pto.comm.tget(%dst_tget_part, %src_part, buf(%stage) + : !pto.partition_tensor_view<64xf32>, + !pto.partition_tensor_view<64xf32>, + !pto.tile_buf) + pto.tload ins(%dst_load_part : !pto.partition_tensor_view<64xf32>) + outs(%tile : !pto.tile_buf) + pto.tstore ins(%tile : !pto.tile_buf) + outs(%sink_part : !pto.partition_tensor_view<64xf32>) + return + } +} + +// CHECK-LABEL: AICORE void issue706_tput_waits_for_prior_tstore( +// CHECK: TSTORE( +// CHECK-NEXT: set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID[[TPUT_ID:[0-9]+]]); +// CHECK: wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID[[TPUT_ID]]); +// CHECK-NEXT: pto::comm::TPUT( + +// CHECK-LABEL: AICORE void issue706_tget_orders_later_tload( +// CHECK: pto::comm::TGET( +// CHECK-NEXT: set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID[[TGET_ID:[0-9]+]]); +// CHECK-NEXT: wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID[[TGET_ID]]); +// CHECK-NEXT: TLOAD( From a4afa97e07c9988789a36cebde1f9bd28ace0b45 Mon Sep 17 00:00:00 2001 From: zhangstevenunity <128771452+zhangstevenunity@users.noreply.github.com> Date: Wed, 27 May 2026 16:44:53 +0800 Subject: [PATCH 3/3] Split GSS p2p staging phases --- .../GraphSyncSolver/SyncSolverIRTranslator.h | 3 +- .../Transforms/GraphSyncSolver/SyncSolver.cpp | 7 +++ .../SyncSolverIRTranslator.cpp | 31 ++++++----- .../pto/issue706_comm_p2p_insert_sync_gss.pto | 51 +++++++++++++++++++ 4 files changed, 74 insertions(+), 18 deletions(-) diff --git a/include/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.h b/include/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.h index 3b388ecf3..ee9196810 100644 --- a/include/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.h +++ b/include/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.h @@ -89,8 +89,7 @@ class IRTranslator { std::unique_ptr getPipeInterfaceOp(pto::OpPipeInterface op, OperationBase *parentOp); - template - std::unique_ptr getP2PCommOp(OP op, OperationBase *parentOp); + template void appendP2PCommOps(OP op, Scope *parentOp); std::unique_ptr getTensorExtractOp(tensor::ExtractOp extractOp, OperationBase *parentOp); diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp index 23a4032a6..72f924443 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolver.cpp @@ -2322,6 +2322,12 @@ void Solver::processConflict(Occurrence *occ1, Occurrence *occ2, } } +static bool isInternalP2PCommPhasePair(RWOperation *rwOp1, + RWOperation *rwOp2) { + return rwOp1 != rwOp2 && rwOp1->op && rwOp1->op == rwOp2->op && + isa(rwOp1->op); +} + // Main processing loop that iterates processingOrders and attempts to // discover and record conflicts. void Solver::processOrders() { @@ -2334,6 +2340,7 @@ void Solver::processOrders() { } if (checkImpossibleOccPair(occ1, occ2) || checkAlreadySynced(occ1, occ2) || skipMMad1DecomposedLoopOpt(occ1, occ2) || + isInternalP2PCommPhasePair(rwOp1, rwOp2) || checkSkipParallelLoop(occ1, occ2) || checkSkipCrossCorePair(occ1, occ2)) { continue; diff --git a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.cpp b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.cpp index ef211a2a2..3c51a99ae 100644 --- a/lib/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.cpp +++ b/lib/PTO/Transforms/GraphSyncSolver/SyncSolverIRTranslator.cpp @@ -193,21 +193,22 @@ IRTranslator::getPipeInterfaceOp(pto::OpPipeInterface op, } template -std::unique_ptr -IRTranslator::getP2PCommOp(OP commOp, OperationBase *parentOp) { - llvm::SmallVector readVals{commOp.getSrc(), commOp.getPing()}; - llvm::SmallVector writeVals{commOp.getDst(), commOp.getPing()}; - if (Value pong = commOp.getPong()) { - readVals.push_back(pong); - writeVals.push_back(pong); - } - auto reads = getMemoryOps(readVals); - auto writes = getMemoryOps(writeVals); +void IRTranslator::appendP2PCommOps(OP commOp, Scope *parentOp) { + llvm::SmallVector scratch{commOp.getPing()}; + if (Value pong = commOp.getPong()) + scratch.push_back(pong); // Synchronous TPUT/TGET hide MTE2 staging and MTE3 commit inside one call. - return std::make_unique( + // Keep the phases separate so scratch writes are attributed to MTE2 and + // scratch reads are attributed to MTE3. + parentOp->body.push_back(std::make_unique( commOp.getOperation(), parentOp, TCoreType::CUBE_OR_VECTOR, - pto::PIPE::PIPE_MTE2, pto::PIPE::PIPE_MTE3, reads, writes); + pto::PIPE::PIPE_MTE2, pto::PIPE::PIPE_MTE2, + getMemoryOps({commOp.getSrc()}), getMemoryOps(scratch))); + parentOp->body.push_back(std::make_unique( + commOp.getOperation(), parentOp, TCoreType::CUBE_OR_VECTOR, + pto::PIPE::PIPE_MTE3, pto::PIPE::PIPE_MTE3, getMemoryOps(scratch), + getMemoryOps({commOp.getDst()}))); } std::unique_ptr @@ -328,11 +329,9 @@ std::unique_ptr IRTranslator::funcIrBuilder(Region ®ion, } if (auto tputOp = dyn_cast(op)) { - if (auto rw = getP2PCommOp(tputOp, parScope)) - parScope->body.push_back(std::move(rw)); + appendP2PCommOps(tputOp, parScope); } else if (auto tgetOp = dyn_cast(op)) { - if (auto rw = getP2PCommOp(tgetOp, parScope)) - parScope->body.push_back(std::move(rw)); + appendP2PCommOps(tgetOp, parScope); } else if (auto pipeOp = dyn_cast(op)) { if (auto rw = getPipeInterfaceOp(pipeOp, parScope)) parScope->body.push_back(std::move(rw)); diff --git a/test/lit/pto/issue706_comm_p2p_insert_sync_gss.pto b/test/lit/pto/issue706_comm_p2p_insert_sync_gss.pto index 3a9618bbf..6bd18bbbd 100644 --- a/test/lit/pto/issue706_comm_p2p_insert_sync_gss.pto +++ b/test/lit/pto/issue706_comm_p2p_insert_sync_gss.pto @@ -107,6 +107,51 @@ module attributes {pto.target_arch = "a2a3"} { outs(%sink_part : !pto.partition_tensor_view<64xf32>) return } + + func.func @issue706_tput_stage_reuse_waits_on_mte2( + %local_src: !pto.ptr, %remote_dst: !pto.ptr) + attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0_i64 = arith.constant 0 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + + %src_view = pto.make_tensor_view %local_src, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + %dst_view = pto.make_tensor_view %remote_dst, shape = [%c64], + strides = [%c1] {layout = #pto.layout} + : !pto.tensor_view<64xf32> + + %src_part = pto.partition_view %src_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + %dst_part = pto.partition_view %dst_view, offsets = [%c0], + sizes = [%c64] + : !pto.tensor_view<64xf32> + -> !pto.partition_tensor_view<64xf32> + + %stage = pto.alloc_tile addr = %c0_i64 + : !pto.tile_buf + %tmp = pto.alloc_tile addr = %c4096_i64 + : !pto.tile_buf + %out = pto.alloc_tile addr = %c8192_i64 + : !pto.tile_buf + + pto.tadd ins(%stage, %tmp + : !pto.tile_buf, + !pto.tile_buf) + outs(%out : !pto.tile_buf) + pto.comm.tput(%dst_part, %src_part, buf(%stage) + : !pto.partition_tensor_view<64xf32>, + !pto.partition_tensor_view<64xf32>, + !pto.tile_buf) + {atomicType = #pto} + return + } } // CHECK-LABEL: AICORE void issue706_tput_waits_for_prior_tstore( @@ -120,3 +165,9 @@ module attributes {pto.target_arch = "a2a3"} { // CHECK-NEXT: set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID[[TGET_ID:[0-9]+]]); // CHECK-NEXT: wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID[[TGET_ID]]); // CHECK-NEXT: TLOAD( + +// CHECK-LABEL: AICORE void issue706_tput_stage_reuse_waits_on_mte2( +// CHECK: TADD( +// CHECK-NEXT: set_flag(PIPE_V, PIPE_MTE2, EVENT_ID[[STAGE_ID:[0-9]+]]); +// CHECK-NEXT: wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID[[STAGE_ID]]); +// CHECK-NEXT: pto::comm::TPUT(