From 59d4a7c6d9b00aa365a36a4157223efb9054872b Mon Sep 17 00:00:00 2001 From: LeslieXMOS Date: Thu, 23 Apr 2026 20:49:12 +0800 Subject: [PATCH 1/3] Added erase pattern for double redundant transpose ops --- xformer/Transforms/OptimizeTranspose.cpp | 51 ++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index aa52ed7f5..237fe2792 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -182,6 +182,51 @@ struct FoldDoubleTransposePattern : public OpRewritePattern { } }; +// Erase Transpose and Inverse Transpose pair that the input dimension and output dimension are the same +struct EraseDoubleTransposePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TFL::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + // Get the permutation used in the transposes + DenseIntElementsAttr perm0; + if (!matchPattern(transposeOp.getPerm(), m_Constant(&perm0))) + return failure(); + + SmallVector users(transposeOp->user_begin(), transposeOp->user_end()); + for (Operation *userOp : users) { + // Check if the user operation is a transpose op + auto userTransposeOp = dyn_cast(userOp); + if (!userTransposeOp) { + return failure(); + } + + // Get the permutation used in the user transposes + DenseIntElementsAttr perm1; + if (!matchPattern(userTransposeOp.getPerm(), m_Constant(&perm1))) + return failure(); + + // Check if this is the inverse of parent transpose + int32_t correspondingDim = 0; + for (auto val : perm1.getValues()) { + if (correspondingDim != perm0.getValues()[val]) { + return failure(); + } + correspondingDim += 1; + } + } + + // Reaching this stage means all user ops are exact inverse transpose ops + // Removing all of them + for (Operation *userOp : users) { + auto userTransposeOp = dyn_cast(userOp); + rewriter.replaceAllUsesWith(userTransposeOp.getResult(), transposeOp.getInput()); + rewriter.eraseOp(userOp); + } + rewriter.eraseOp(transposeOp); + return success(); + } +}; + // Replace TransposeOp with ReshapeOp if equivalent // Transpose is equivalent to reshape if we only permute consecutive dimensions // and only one of those permuted dimensions isn't of size 1 @@ -719,6 +764,12 @@ void OptimizeTranspose::runOnOperation() { } (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + + // Erase double transpose optimizations + RewritePatternSet erasePatterns(ctx); + erasePatterns.insert(ctx); + + (void)applyPatternsAndFoldGreedily(func, std::move(erasePatterns)); } } // namespace From 9b25937bea9f51316a96ec9cff10efcc8b26286f Mon Sep 17 00:00:00 2001 From: LeslieXMOS Date: Thu, 23 Apr 2026 23:03:44 +0800 Subject: [PATCH 2/3] Got a better approach that fit more cases after a night run --- xformer/Transforms/OptimizeTranspose.cpp | 47 +++++++++++++++++------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index 237fe2792..e37e90056 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -187,43 +187,60 @@ struct EraseDoubleTransposePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TFL::TransposeOp transposeOp, PatternRewriter &rewriter) const override { + bool IRModified = false; // Get the permutation used in the transposes DenseIntElementsAttr perm0; - if (!matchPattern(transposeOp.getPerm(), m_Constant(&perm0))) + if (!matchPattern(transposeOp.getPerm(), m_Constant(&perm0))) { return failure(); + } SmallVector users(transposeOp->user_begin(), transposeOp->user_end()); + bool allUserErased = true; for (Operation *userOp : users) { // Check if the user operation is a transpose op auto userTransposeOp = dyn_cast(userOp); if (!userTransposeOp) { - return failure(); + allUserErased = false; + continue; } // Get the permutation used in the user transposes DenseIntElementsAttr perm1; - if (!matchPattern(userTransposeOp.getPerm(), m_Constant(&perm1))) - return failure(); + if (!matchPattern(userTransposeOp.getPerm(), m_Constant(&perm1))) { + allUserErased = false; + continue; + } // Check if this is the inverse of parent transpose int32_t correspondingDim = 0; + bool userIsInverseTranspose = true; for (auto val : perm1.getValues()) { if (correspondingDim != perm0.getValues()[val]) { - return failure(); + userIsInverseTranspose = false; + break; } correspondingDim += 1; } + + if (userIsInverseTranspose) { + // Can bypass this transpose -> inverse transpose pair + rewriter.replaceAllUsesWith(userTransposeOp.getResult(), transposeOp.getInput()); + // And erase the inverse transpose ops + rewriter.eraseOp(userTransposeOp); + IRModified = true; + } } - // Reaching this stage means all user ops are exact inverse transpose ops - // Removing all of them - for (Operation *userOp : users) { - auto userTransposeOp = dyn_cast(userOp); - rewriter.replaceAllUsesWith(userTransposeOp.getResult(), transposeOp.getInput()); - rewriter.eraseOp(userOp); + if (allUserErased) { + // All user ops are the equivant inverse transpose ops + // Remove the transpose ops + rewriter.eraseOp(transposeOp); + IRModified = true; } - rewriter.eraseOp(transposeOp); - return success(); + if (IRModified) { + return success(); + } + return failure(); } }; @@ -765,9 +782,11 @@ void OptimizeTranspose::runOnOperation() { (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); - // Erase double transpose optimizations + // Erase double transpose optimizations and merge if there is leftover RewritePatternSet erasePatterns(ctx); erasePatterns.insert(ctx); + erasePatterns.insert(ctx); + erasePatterns.insert(ctx); (void)applyPatternsAndFoldGreedily(func, std::move(erasePatterns)); } From 261b7abf57478e6e655ca8d5f1603279d3a4dc36 Mon Sep 17 00:00:00 2001 From: LeslieXMOS Date: Fri, 24 Apr 2026 00:28:00 +0800 Subject: [PATCH 3/3] Bug fix: prevent transpose which still have user got erased --- xformer/Transforms/OptimizeTranspose.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xformer/Transforms/OptimizeTranspose.cpp b/xformer/Transforms/OptimizeTranspose.cpp index e37e90056..9d8dbade9 100644 --- a/xformer/Transforms/OptimizeTranspose.cpp +++ b/xformer/Transforms/OptimizeTranspose.cpp @@ -228,6 +228,8 @@ struct EraseDoubleTransposePattern : public OpRewritePattern { // And erase the inverse transpose ops rewriter.eraseOp(userTransposeOp); IRModified = true; + } else { + allUserErased = false; } }