Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,9 @@ struct MatmulOpLowering : public OpRewritePattern<linalg::MatmulOp> {
std::vector<affine::AffineForOp> outerLoops;
std::vector<affine::AffineForOp> innerLoops;

bool isAInitialized = false;
bool isBInitialized = false;

// Find accumulation loops and set last outerloop
auto affineForOp = llvm::dyn_cast_or_null<affine::AffineForOp>(op->getParentRegion()->getParentOp());
while (affineForOp) {
Expand All @@ -347,16 +350,20 @@ struct MatmulOpLowering : public OpRewritePattern<linalg::MatmulOp> {
}
affineForOp = llvm::dyn_cast_or_null<affine::AffineForOp>(affineForOp->getParentOp());
}

bool is_conv2d = (innerLoops.size() == 4);
assert(accumulationLoops.size()>=1);
assert(outerLoops.size()>=2);

affine::AffineForOp tile_k_w_loop, tile_o_h_loop, tile_o_w_loop;

if (is_conv2d) {
tile_k_w_loop = innerLoops.at(1);
tile_o_h_loop = innerLoops.at(2);
tile_o_w_loop = innerLoops.at(3);
rewriter.setInsertionPoint(&tile_k_w_loop.getBody()->back()); // to reuse CONV kernel
}

// Constants
Value c0 = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value rvl = rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(nr_element));
Expand All @@ -370,6 +377,12 @@ struct MatmulOpLowering : public OpRewritePattern<linalg::MatmulOp> {
// Set Last outer loop
affineForOp = outerLoops.back();
int subtileM, subtileN, subtileK;

llvm::SmallVector<int32_t, 3> idxMap = {0, 1, 2};

if (auto idxMapAttr = op->getAttrOfType<mlir::DenseI32ArrayAttr>("idx_map"))
idxMap.assign(idxMapAttr.asArrayRef().begin(), idxMapAttr.asArrayRef().end());

affineForOp->walk([&](mlir::Operation *nestedOp) {
if (auto dmaStartOp = llvm::dyn_cast<memref::DmaStartOp>(nestedOp)) { // Replace DMAStartOp with actual `dma_start` op type
auto result = getDramMemRef(dmaStartOp);
Expand All @@ -392,35 +405,64 @@ struct MatmulOpLowering : public OpRewritePattern<linalg::MatmulOp> {
auto blockArg = mlir::cast<mlir::BlockArgument>(result.first);
if (!blockArg)
return WalkResult::advance();
if (blockArg.getArgNumber() == 0) {
if (blockArg.getArgNumber() == idxMap[0]) {
ADmaTag = dmaStartOp.getTagMemRef(); // Assuming `getTag()` retrieves the `tag` from `dma_start`.
ADmaAsync = getAsyncValue(dmaStartOp);
llvm::SmallVector<mlir::Attribute> dmaSubtile = getSubtileSize(dmaStartOp);
subtileM = llvm::dyn_cast<mlir::IntegerAttr>(dmaSubtile[dmaSubtile.size() - 2]).getInt();
subtileK = llvm::dyn_cast<mlir::IntegerAttr>(dmaSubtile[dmaSubtile.size() - 1]).getInt();
} else if (blockArg.getArgNumber() == 1) {

isAInitialized = true;
} else if (blockArg.getArgNumber() == idxMap[1]) {
BDmaTag = dmaStartOp.getTagMemRef(); // Assuming `getTag()` retrieves the `tag` from `dma_start`.
BDmaAsync = getAsyncValue(dmaStartOp);
llvm::SmallVector<mlir::Attribute> dmaSubtile = getSubtileSize(dmaStartOp);
subtileN = llvm::dyn_cast<mlir::IntegerAttr>(dmaSubtile[dmaSubtile.size() - 1]).getInt();
} else if (blockArg.getArgNumber() == 2) {

isBInitialized = true;
} else if (blockArg.getArgNumber() == idxMap[2]) {
BiasDmaTag = dmaStartOp.getTagMemRef(); // Assuming `getTag()` retrieves the `tag` from `dma_start`.
BiasDmaAsync = getAsyncValue(dmaStartOp);
BiasDMAIndices = dmaStartOp.getTagIndices();
}
}
else if (auto vectorStoreOp = llvm::dyn_cast<affine::AffineVectorStoreOp>(nestedOp)) {
// Strips at most one memref.reinterpret_cast / memref.cast from `value` and
// yields that cast's source operand. A value with no defining op, or one
// produced by any other op, is returned unchanged. Only a single cast level
// is peeled; chains of casts are not followed.
auto getOrigin = [](Value value) -> Value {
  Operation *producer = value.getDefiningOp();
  if (!producer)
    return value;
  if (auto reinterpretOp = llvm::dyn_cast<memref::ReinterpretCastOp>(producer))
    return reinterpretOp.getSource();
  if (auto castOp = llvm::dyn_cast<memref::CastOp>(producer))
    return castOp.getSource();
  return value;
};

Value sramRef = vectorStoreOp.getMemRef();
Value rootSramRef = getOrigin(sramRef);
Value rootA = getOrigin(A);
Value rootB = getOrigin(B);

if (rootSramRef == rootA)
isAInitialized = true;
else if (rootSramRef == rootB)
isBInitialized = true;
}

return WalkResult::advance();
});

int KStep = subtileK;
int push_length = subtileM > SYSTOLIC_SIZE ? SYSTOLIC_SIZE : subtileM;
int MStep = M > push_length ? push_length : M;
int NStep = subtileN > SYSTOLIC_SIZE ? SYSTOLIC_SIZE : subtileN;
Value vector_elements = rewriter.create<mlir::arith::ConstantIndexOp>(rewriter.getUnknownLoc(), push_length);

if (!ADmaTag || !BDmaTag) {
op.emitError () << "Failed to locate dma_start for retrieving tag.";
return failure();
if (!isAInitialized || !isBInitialized) {
op.emitError () << "Failed to locate data source for operands. Neither dma_start nor preceding vector_store found for A or B.";
return failure();
}

// Create inner loops for micro tile(SRAM <-> VRF).
affine::AffineForOp inner_loop;
Value zero_vector;
if (N > SYSTOLIC_SIZE) {
Expand Down Expand Up @@ -452,14 +494,17 @@ struct MatmulOpLowering : public OpRewritePattern<linalg::MatmulOp> {
mlir::AffineExpr BTagExpr = rewriter.getAffineDimExpr(0) * -1;
llvm::SmallVector<mlir::Value, 4> ATagOperands = {accumulationLoops.at(0).getInductionVar()};
llvm::SmallVector<mlir::Value, 4> BTagOperands = {accumulationLoops.at(0).getInductionVar()};

for (size_t i = 1; i < numAccumulationLoops; ++i) {
ATagExpr = ATagExpr + rewriter.getAffineDimExpr(i) * -1;
BTagExpr = BTagExpr + rewriter.getAffineDimExpr(i) * -1;
ATagOperands.push_back(accumulationLoops.at(i).getInductionVar());
BTagOperands.push_back(accumulationLoops.at(i).getInductionVar());
}

int ADimOffset = numAccumulationLoops;
int BDimOffset = numAccumulationLoops;

if (is_conv2d) { // innerloop : K_H, K_W, O_H, O_W
/* FIXME. this is totally heuristic based lowering... */
int64_t kW;
Expand All @@ -471,6 +516,8 @@ struct MatmulOpLowering : public OpRewritePattern<linalg::MatmulOp> {
rewriter.getAffineDimExpr(BDimOffset-2)*((N/subtileN)*(K/subtileK)*kW) + \
rewriter.getAffineDimExpr(BDimOffset-1)*((N/subtileN)*(K/subtileK));
}

// Create a dma_wait for B.
BTagExpr = BTagExpr + rewriter.getAffineDimExpr(BDimOffset).floorDiv((NStep+SYSTOLIC_SIZE-1)/SYSTOLIC_SIZE)*(K/KStep) + \
rewriter.getAffineDimExpr(BDimOffset+1).floorDiv((KStep+SYSTOLIC_SIZE-1)/SYSTOLIC_SIZE)*1;
auto BTagMap = mlir::AffineMap::get(BDimOffset+2, 0, BTagExpr);
Expand Down Expand Up @@ -499,6 +546,7 @@ struct MatmulOpLowering : public OpRewritePattern<linalg::MatmulOp> {
}
rewriter.create<vcix::BinaryNoDestImmOp>(weight_vector.getLoc(), vwpush_opcode, weight_vector, zeroImmAttr, zeroImmAttr, rvl);
}

if (is_conv2d && inner_loop) { // return to inner loop location
tile_o_h_loop->moveBefore(inner_loop.getBody(), std::prev(inner_loop.getBody()->end()));
rewriter.setInsertionPointToStart(tile_o_h_loop.getBody());
Expand Down Expand Up @@ -561,6 +609,8 @@ struct MatmulOpLowering : public OpRewritePattern<linalg::MatmulOp> {
ATagExpr = ATagExpr + rewriter.getAffineDimExpr(ADimOffset-2)*((K/subtileK)*(M/subtileM)*offset_h*coeff_h) + \
rewriter.getAffineDimExpr(ADimOffset-1)*((K/subtileK)*(M/subtileM)*offset_w);
}

// Create a dma_wait for A.
ATagExpr = ATagExpr + rewriter.getAffineDimExpr(ADimOffset)*(M/MStep) + \
rewriter.getAffineDimExpr(ADimOffset+1).floorDiv((MStep+SYSTOLIC_SIZE-1)/SYSTOLIC_SIZE);
auto ATagMap = mlir::AffineMap::get(ADimOffset+2, 0, ATagExpr);
Expand All @@ -571,6 +621,8 @@ struct MatmulOpLowering : public OpRewritePattern<linalg::MatmulOp> {
auto ATagIdx = rewriter.create<affine::AffineApplyOp>(loc, ATagMap, ATagOperands);
if (ADmaAsync)
rewriter.create<memref::DmaWaitOp>(loc, ADmaTag, ValueRange{ATagIdx}, numElements);

// Create a dma_wait for Bias.
if (BiasDmaTag) {
/* Bias could be 1D or 2D */
Value first_index = BiasDMAIndices[0].getDefiningOp<mlir::arith::ConstantIndexOp>() ? c0 : n_tag_idx;
Expand All @@ -582,10 +634,9 @@ struct MatmulOpLowering : public OpRewritePattern<linalg::MatmulOp> {
rewriter.create<memref::DmaWaitOp>(loc, BiasDmaTag, ValueRange{BiasTagIdx}, numElements);
}



// For vpush input loop part
int64_t M_LOOP = M > push_length ? push_length : M;

for (int i=0; i<M_LOOP; i+=nr_element) { // MxK
Value i_val = rewriter.create<mlir::arith::ConstantIndexOp>(rewriter.getUnknownLoc(), i);
Value spad_idx = rewriter.create<affine::AffineApplyOp>(loc, spadIdxMapAttr,
Expand All @@ -599,6 +650,7 @@ struct MatmulOpLowering : public OpRewritePattern<linalg::MatmulOp> {

// Compute instruction
rewriter.create<vcix::UnaryNoDestImmOp>(loc, compute_opcode, zeroImmAttr, compute_cycle, zeroImmAttr, sew, lmul, rvl);

// For vpop loop part
for (int i=0; i<M_LOOP; i+=nr_element) { // MxN
Value i_val = rewriter.create<mlir::arith::ConstantIndexOp>(rewriter.getUnknownLoc(), i);
Expand Down