diff --git a/mlir/test/lib/Conversion/PyTorchSimToVCIX/TestPyTorchSimToVCIXConversion.cpp b/mlir/test/lib/Conversion/PyTorchSimToVCIX/TestPyTorchSimToVCIXConversion.cpp index 0fcb105094f4..77ae16314814 100644 --- a/mlir/test/lib/Conversion/PyTorchSimToVCIX/TestPyTorchSimToVCIXConversion.cpp +++ b/mlir/test/lib/Conversion/PyTorchSimToVCIX/TestPyTorchSimToVCIXConversion.cpp @@ -333,6 +333,9 @@ struct MatmulOpLowering : public OpRewritePattern { std::vector outerLoops; std::vector innerLoops; + bool isAInitialized = false; + bool isBInitialized = false; + // Find accumulation loops and set last outerloop auto affineForOp = llvm::dyn_cast_or_null(op->getParentRegion()->getParentOp()); while (affineForOp) { @@ -347,16 +350,20 @@ struct MatmulOpLowering : public OpRewritePattern { } affineForOp = llvm::dyn_cast_or_null(affineForOp->getParentOp()); } + bool is_conv2d = (innerLoops.size() == 4); assert(accumulationLoops.size()>=1); assert(outerLoops.size()>=2); + affine::AffineForOp tile_k_w_loop, tile_o_h_loop, tile_o_w_loop; + if (is_conv2d) { tile_k_w_loop = innerLoops.at(1); tile_o_h_loop = innerLoops.at(2); tile_o_w_loop = innerLoops.at(3); rewriter.setInsertionPoint(&tile_k_w_loop.getBody()->back()); // to reuse CONV kernel } + // Constants Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); Value rvl = rewriter.create(loc, rewriter.getI64IntegerAttr(nr_element)); @@ -370,6 +377,12 @@ struct MatmulOpLowering : public OpRewritePattern { // Set Last outer loop affineForOp = outerLoops.back(); int subtileM, subtileN, subtileK; + + llvm::SmallVector idxMap = {0, 1, 2}; + + if (auto idxMapAttr = op->getAttrOfType("idx_map")) + idxMap.assign(idxMapAttr.asArrayRef().begin(), idxMapAttr.asArrayRef().end()); + affineForOp->walk([&](mlir::Operation *nestedOp) { if (auto dmaStartOp = llvm::dyn_cast(nestedOp)) { // Replace DMAStartOp with actual `dma_start` op type auto result = getDramMemRef(dmaStartOp); @@ -392,35 +405,64 @@ struct MatmulOpLowering : public OpRewritePattern { auto blockArg = mlir::cast(result.first); if (!blockArg) return WalkResult::advance(); - if (blockArg.getArgNumber() == 0) { + if (blockArg.getArgNumber() == idxMap[0]) { ADmaTag = dmaStartOp.getTagMemRef(); // Assuming `getTag()` retrieves the `tag` from `dma_start`. ADmaAsync = getAsyncValue(dmaStartOp); llvm::SmallVector dmaSubtile = getSubtileSize(dmaStartOp); subtileM = llvm::dyn_cast(dmaSubtile[dmaSubtile.size() - 2]).getInt(); subtileK = llvm::dyn_cast(dmaSubtile[dmaSubtile.size() - 1]).getInt(); - } else if (blockArg.getArgNumber() == 1) { + + isAInitialized = true; + } else if (blockArg.getArgNumber() == idxMap[1]) { BDmaTag = dmaStartOp.getTagMemRef(); // Assuming `getTag()` retrieves the `tag` from `dma_start`. BDmaAsync = getAsyncValue(dmaStartOp); llvm::SmallVector dmaSubtile = getSubtileSize(dmaStartOp); subtileN = llvm::dyn_cast(dmaSubtile[dmaSubtile.size() - 1]).getInt(); - } else if (blockArg.getArgNumber() == 2) { + + isBInitialized = true; + } else if (blockArg.getArgNumber() == idxMap[2]) { BiasDmaTag = dmaStartOp.getTagMemRef(); // Assuming `getTag()` retrieves the `tag` from `dma_start`. BiasDmaAsync = getAsyncValue(dmaStartOp); BiasDMAIndices = dmaStartOp.getTagIndices(); } + } + else if (auto vectorStoreOp = llvm::dyn_cast(nestedOp)) { + auto getOrigin = [](Value v) -> Value { + if (Operation *defOp = v.getDefiningOp()) { + if (auto reinterpret = llvm::dyn_cast(defOp)) + return reinterpret.getSource(); + if (auto memcast = llvm::dyn_cast(defOp)) + return memcast.getSource(); + } + return v; + }; + + Value sramRef = vectorStoreOp.getMemRef(); + Value rootSramRef = getOrigin(sramRef); + Value rootA = getOrigin(A); + Value rootB = getOrigin(B); + + if (rootSramRef == rootA) + isAInitialized = true; + else if (rootSramRef == rootB) + isBInitialized = true; } + return WalkResult::advance(); }); + int KStep = subtileK; int push_length = subtileM > SYSTOLIC_SIZE ? SYSTOLIC_SIZE : subtileM; int MStep = M > push_length ? push_length : M; int NStep = subtileN > SYSTOLIC_SIZE ? SYSTOLIC_SIZE : subtileN; Value vector_elements = rewriter.create(rewriter.getUnknownLoc(), push_length); - if (!ADmaTag || !BDmaTag) { - op.emitError () << "Failed to locate dma_start for retrieving tag."; - return failure(); + if (!isAInitialized || !isBInitialized) { + op.emitError () << "Failed to locate data source for operands. Neither dma_start nor preceding vector_store found for A or B."; + return failure(); } + + // Create inner loops for micro tile(SRAM <-> VRF). affine::AffineForOp inner_loop; Value zero_vector; if (N > SYSTOLIC_SIZE) { @@ -452,14 +494,17 @@ struct MatmulOpLowering : public OpRewritePattern { mlir::AffineExpr BTagExpr = rewriter.getAffineDimExpr(0) * -1; llvm::SmallVector ATagOperands = {accumulationLoops.at(0).getInductionVar()}; llvm::SmallVector BTagOperands = {accumulationLoops.at(0).getInductionVar()}; + for (size_t i = 1; i < numAccumulationLoops; ++i) { ATagExpr = ATagExpr + rewriter.getAffineDimExpr(i) * -1; BTagExpr = BTagExpr + rewriter.getAffineDimExpr(i) * -1; ATagOperands.push_back(accumulationLoops.at(i).getInductionVar()); BTagOperands.push_back(accumulationLoops.at(i).getInductionVar()); } + int ADimOffset = numAccumulationLoops; int BDimOffset = numAccumulationLoops; + if (is_conv2d) { // innerloop : K_H, K_W, O_H, O_W /* FIXME. this is totally heuristic based lowering... */ int64_t kW; @@ -471,6 +516,8 @@ struct MatmulOpLowering : public OpRewritePattern { rewriter.getAffineDimExpr(BDimOffset-2)*((N/subtileN)*(K/subtileK)*kW) + \ rewriter.getAffineDimExpr(BDimOffset-1)*((N/subtileN)*(K/subtileK)); } + + // Create a dma_wait for B. BTagExpr = BTagExpr + rewriter.getAffineDimExpr(BDimOffset).floorDiv((NStep+SYSTOLIC_SIZE-1)/SYSTOLIC_SIZE)*(K/KStep) + \ rewriter.getAffineDimExpr(BDimOffset+1).floorDiv((KStep+SYSTOLIC_SIZE-1)/SYSTOLIC_SIZE)*1; auto BTagMap = mlir::AffineMap::get(BDimOffset+2, 0, BTagExpr); @@ -499,6 +546,7 @@ struct MatmulOpLowering : public OpRewritePattern { } rewriter.create(weight_vector.getLoc(), vwpush_opcode, weight_vector, zeroImmAttr, zeroImmAttr, rvl); } + if (is_conv2d && inner_loop) { // return to inner loop location tile_o_h_loop->moveBefore(inner_loop.getBody(), std::prev(inner_loop.getBody()->end())); rewriter.setInsertionPointToStart(tile_o_h_loop.getBody()); @@ -561,6 +609,8 @@ struct MatmulOpLowering : public OpRewritePattern { ATagExpr = ATagExpr + rewriter.getAffineDimExpr(ADimOffset-2)*((K/subtileK)*(M/subtileM)*offset_h*coeff_h) + \ rewriter.getAffineDimExpr(ADimOffset-1)*((K/subtileK)*(M/subtileM)*offset_w); } + + // Create a dma_wait for A. ATagExpr = ATagExpr + rewriter.getAffineDimExpr(ADimOffset)*(M/MStep) + \ rewriter.getAffineDimExpr(ADimOffset+1).floorDiv((MStep+SYSTOLIC_SIZE-1)/SYSTOLIC_SIZE); auto ATagMap = mlir::AffineMap::get(ADimOffset+2, 0, ATagExpr); @@ -571,6 +621,8 @@ struct MatmulOpLowering : public OpRewritePattern { auto ATagIdx = rewriter.create(loc, ATagMap, ATagOperands); if (ADmaAsync) rewriter.create(loc, ADmaTag, ValueRange{ATagIdx}, numElements); + + // Create a dma_wait for Bias. if (BiasDmaTag) { /* Bias could be 1D or 2D */ Value first_index = BiasDMAIndices[0].getDefiningOp() ? c0 : n_tag_idx; @@ -582,10 +634,9 @@ struct MatmulOpLowering : public OpRewritePattern { rewriter.create(loc, BiasDmaTag, ValueRange{BiasTagIdx}, numElements); } - - // For vpush input loop part int64_t M_LOOP = M > push_length ? push_length : M; + for (int i=0; i(rewriter.getUnknownLoc(), i); Value spad_idx = rewriter.create(loc, spadIdxMapAttr, @@ -599,6 +650,7 @@ struct MatmulOpLowering : public OpRewritePattern { // Compute instruction rewriter.create(loc, compute_opcode, zeroImmAttr, compute_cycle, zeroImmAttr, sew, lmul, rvl); + // For vpop loop part for (int i=0; i(rewriter.getUnknownLoc(), i);