diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp
index aa52ed7f5..9d8dbade9 100644
--- a/xformer/Transforms/OptimizeTranspose.cpp
+++ b/xformer/Transforms/OptimizeTranspose.cpp
@@ -182,6 +182,79 @@ struct FoldDoubleTransposePattern : public OpRewritePattern {
   }
 };
 
+// Erase a Transpose followed by its inverse Transpose. The pair is a no-op, so
+// every consumer of the inverse transpose is rewired to the original input,
+// provided the input type and the inverse transpose's result type are the same.
+// The first transpose is erased too once all of its users have been bypassed.
+struct EraseDoubleTransposePattern : public OpRewritePattern<TFL::TransposeOp> {
+  using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;
+
+  // Returns true when applying `second` after `first` restores the original
+  // dimension order, i.e. first[second[i]] == i for every index i.
+  static bool isInversePermutation(DenseIntElementsAttr first,
+                                   DenseIntElementsAttr second) {
+    const int64_t rank = first.getNumElements();
+    if (second.getNumElements() != rank) {
+      return false;
+    }
+    auto firstValues = first.getValues<int32_t>();
+    int64_t dim = 0;
+    for (int32_t srcDim : second.getValues<int32_t>()) {
+      // Reject out-of-range entries before using them to index `first`.
+      if (srcDim < 0 || srcDim >= rank || firstValues[srcDim] != dim) {
+        return false;
+      }
+      ++dim;
+    }
+    return true;
+  }
+
+  // Matches a constant-permutation transpose and bypasses every user that is
+  // its inverse transpose. Returns success() iff the IR was modified.
+  LogicalResult matchAndRewrite(TFL::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    // Get the permutation used in the transpose; it must be constant.
+    DenseIntElementsAttr perm0;
+    if (!matchPattern(transposeOp.getPerm(), m_Constant(&perm0))) {
+      return failure();
+    }
+
+    // Snapshot the users first: the use list changes while we rewrite.
+    SmallVector<Operation *> users(transposeOp->user_begin(),
+                                   transposeOp->user_end());
+    bool IRModified = false;
+    bool allUsersErased = true;
+    for (Operation *userOp : users) {
+      // Only bypass a user that is a constant-permutation transpose undoing
+      // this one AND producing exactly the type of the original input;
+      // otherwise rewiring its uses would create type-mismatched IR.
+      auto userTransposeOp = dyn_cast<TFL::TransposeOp>(userOp);
+      DenseIntElementsAttr perm1;
+      if (!userTransposeOp ||
+          !matchPattern(userTransposeOp.getPerm(), m_Constant(&perm1)) ||
+          !isInversePermutation(perm0, perm1) ||
+          userTransposeOp.getResult().getType() !=
+              transposeOp.getInput().getType()) {
+        allUsersErased = false;
+        continue;
+      }
+
+      // Forward the original input to every consumer of the inverse
+      // transpose and drop the now-redundant op.
+      rewriter.replaceOp(userTransposeOp, transposeOp.getInput());
+      IRModified = true;
+    }
+
+    if (allUsersErased) {
+      // All users were equivalent inverse transposes and have been bypassed,
+      // so the transpose itself is dead and can be removed.
+      rewriter.eraseOp(transposeOp);
+      IRModified = true;
+    }
+    return success(IRModified);
+  }
+};
+
 // Replace TransposeOp with ReshapeOp if equivalent
 // Transpose is equivalent to reshape if we only permute consecutive dimensions
 // and only one of those permuted dimensions isn't of size 1
@@ -719,6 +792,16 @@ void OptimizeTranspose::runOnOperation() {
   }
 
   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+
+  // Erase double transpose optimizations and merge if there is leftover
+  // NOTE(review): the pattern type arguments of the insert calls below are
+  // not visible here (insert<...>(ctx) expected) -- confirm before merging.
+  RewritePatternSet erasePatterns(ctx);
+  erasePatterns.insert(ctx);
+  erasePatterns.insert(ctx);
+  erasePatterns.insert(ctx);
+
+  (void)applyPatternsAndFoldGreedily(func, std::move(erasePatterns));
 }
 
 } // namespace