Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions xformer/Transforms/OptimizeTranspose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,70 @@ struct FoldDoubleTransposePattern : public OpRewritePattern<TFL::TransposeOp> {
}
};

// Erase Transpose and Inverse Transpose pair that the input dimension and output dimension are the same
struct EraseDoubleTransposePattern : public OpRewritePattern<TFL::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TFL::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
bool IRModified = false;
// Get the permutation used in the transposes
DenseIntElementsAttr perm0;
if (!matchPattern(transposeOp.getPerm(), m_Constant(&perm0))) {
return failure();
}

SmallVector<Operation *> users(transposeOp->user_begin(), transposeOp->user_end());
bool allUserErased = true;
for (Operation *userOp : users) {
// Check if the user operation is a transpose op
auto userTransposeOp = dyn_cast<TFL::TransposeOp>(userOp);
if (!userTransposeOp) {
allUserErased = false;
continue;
}

// Get the permutation used in the user transposes
DenseIntElementsAttr perm1;
if (!matchPattern(userTransposeOp.getPerm(), m_Constant(&perm1))) {
allUserErased = false;
continue;
}

// Check if this is the inverse of parent transpose
int32_t correspondingDim = 0;
bool userIsInverseTranspose = true;
for (auto val : perm1.getValues<int32_t>()) {
if (correspondingDim != perm0.getValues<int32_t>()[val]) {
userIsInverseTranspose = false;
break;
}
correspondingDim += 1;
}

if (userIsInverseTranspose) {
// Can bypass this transpose -> inverse transpose pair
rewriter.replaceAllUsesWith(userTransposeOp.getResult(), transposeOp.getInput());
// And erase the inverse transpose ops
rewriter.eraseOp(userTransposeOp);
IRModified = true;
} else {
allUserErased = false;
}
}

if (allUserErased) {
// All user ops are the equivant inverse transpose ops
// Remove the transpose ops
rewriter.eraseOp(transposeOp);
IRModified = true;
}
if (IRModified) {
return success();
}
return failure();
}
};

// Replace TransposeOp with ReshapeOp if equivalent
// Transpose is equivalent to reshape if we only permute consecutive dimensions
// and only one of those permuted dimensions isn't of size 1
Expand Down Expand Up @@ -719,6 +783,14 @@ void OptimizeTranspose::runOnOperation() {
}

(void)applyPatternsAndFoldGreedily(func, std::move(patterns));

// Erase double transpose optimizations and merge if there is leftover
RewritePatternSet erasePatterns(ctx);
erasePatterns.insert<EraseDoubleTransposePattern>(ctx);
erasePatterns.insert<FoldDoubleTransposePattern>(ctx);
erasePatterns.insert<FoldTransposeToReshapePattern>(ctx);

(void)applyPatternsAndFoldGreedily(func, std::move(erasePatterns));
}
} // namespace

Expand Down