3636
3737#include < llvm/Support/Debug.h>
3838
39+ #include < iostream>
3940#include < memory>
4041
4142namespace mlir {
@@ -45,16 +46,161 @@ namespace gc {
4546
4647namespace {
4748
49+ struct FlashAttentionConfig {
50+ int RowBlock, ColumnBlock;
51+ };
52+
53+ static FlashAttentionConfig
54+ getDefaultFlashAttentionConfig (linalgx::ScaledDotProductAttentionOp &sdpaOp) {
55+ // TODO: allow tuning
56+ auto seqLen = sdpaOp.getShape (sdpaOp.getDpsInputOperand (0 ))[2 ];
57+ FlashAttentionConfig cfg;
58+
59+ // cfg.RowBlock = seqLen / 64;
60+ // cfg.ColBlock = seqLen / 64;
61+ return cfg;
62+ }
63+
4864struct MHAToFlashAttention
49- : public OpInterfaceRewritePattern<linalg::LinalgOp> {
50- using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
65+ : public OpRewritePattern<linalgx::ScaledDotProductAttentionOp> {
66+ using OpRewritePattern<
67+ linalgx::ScaledDotProductAttentionOp>::OpRewritePattern;
68+
69+ struct OuterLoopGenerationResult {
70+ // / Tiled operations that are generated during tiling. The order does not
71+ // / matter except the last op. The replacements are expected to be the
72+ // / results of the last op.
73+ SmallVector<Operation *> tiledOps;
74+ // / The `scf.for` operations that iterate over the tiles.
75+ SmallVector<LoopLikeOpInterface> loops;
76+ SmallVector<LoopLikeOpInterface> reductionLoops;
77+ };
5178
52- LogicalResult matchAndRewrite (linalg::LinalgOp linalgOp,
79+ // FailureOr<OuterLoopGenerationResult>
80+ // outerLoopGeneration(RewriterBase &rewriter, linalg::LinalgOp linalgOp)
81+ // const {
82+ // SmallVector<unsigned> RowDimPos, ColDimPos;
83+ // linalgOp.getReductionDims(KDimPos);
84+ // getMatmulParallelDims(linalgOp, 0, MDimPos);
85+ // getMatmulParallelDims(linalgOp, 1, NDimPos);
86+
87+ // OuterLoopGenerationOption option;
88+ // auto iteratorTypes = linalgOp.getIteratorTypesArray();
89+ // auto KFirstDim = (int)getOprandDim(linalgOp, KDimPos[0], 1);
90+ // auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0);
91+ // auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1);
92+ // auto KParallelBlockSize =
93+ // KDimPos.size() > 1
94+ // ? divAndCeil(KFirstDim, cfg.KThreads)
95+ // : divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) *
96+ // cfg.KBlock;
97+ // auto MParallelBlockSize =
98+ // MDimPos.size() > 1
99+ // ? divAndCeil(MFirstDim, cfg.MThreads)
100+ // : divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) *
101+ // cfg.MBlock;
102+ // auto NParallelBlockSize =
103+ // NDimPos.size() > 1
104+ // ? divAndCeil(NFirstDim, cfg.NThreads)
105+ // : divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) *
106+ // cfg.NBlock;
107+ // auto KOuterBlockSize = KDimPos.size() > 1
108+ // ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1
109+ // : cfg.KBlock;
110+ // auto MOuterBlockSize = MDimPos.size() > 1
111+ // ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1
112+ // : cfg.MBlock;
113+ // auto NOuterBlockSize = NDimPos.size() > 1
114+ // ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1
115+ // : cfg.NBlock;
116+ // // Outer
117+ // option.nestedTileSizes.emplace_back(SmallVector<int>{
118+ // MParallelBlockSize, NParallelBlockSize, KParallelBlockSize});
119+ // option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp);
120+ // option.loopDim.emplace_back(
121+ // SmallVector<int>{(int)MDimPos[0], (int)NDimPos[0], (int)KDimPos[0]});
122+ // // Middle
123+ // for (auto [tile, dim] :
124+ // llvm::zip(SmallVector<int>{MOuterBlockSize, NOuterBlockSize,
125+ // KOuterBlockSize},
126+ // SmallVector<int>{(int)MDimPos[0], (int)NDimPos[0],
127+ // (int)KDimPos[0]})) {
128+ // option.nestedTileSizes.emplace_back(SmallVector<int>{tile});
129+ // option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
130+ // option.loopDim.emplace_back(SmallVector<int>{dim});
131+ // }
132+ // // Inner
133+ // if (KDimPos.size() == 1) {
134+ // option.nestedTileSizes.emplace_back(SmallVector<int>{cfg.KBlock});
135+ // option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
136+ // option.loopDim.emplace_back(SmallVector<int>{(int)KDimPos.back()});
137+ // }
138+ // if (MDimPos.size() == 1) {
139+ // option.nestedTileSizes.emplace_back(
140+ // SmallVector<int>{cfg.innerMostMBlock});
141+ // option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
142+ // option.loopDim.emplace_back(SmallVector<int>{(int)MDimPos.back()});
143+ // }
144+ // if (NDimPos.size() == 1) {
145+ // option.nestedTileSizes.emplace_back(
146+ // SmallVector<int>{cfg.innerMostNBlock});
147+ // option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
148+ // option.loopDim.emplace_back(SmallVector<int>{(int)NDimPos.back()});
149+ // }
150+ // for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) {
151+ // if (dim != MDimPos.back() && dim != NDimPos.back() &&
152+ // iteratorTypes[dim] != mlir::utils::IteratorType::reduction) {
153+ // option.nestedTileSizes.emplace_back(SmallVector<int>{1});
154+ // option.loopType.emplace_back(
155+ // OuterLoopGenerationOption::LoopType::ForOp);
156+ // option.loopDim.emplace_back(SmallVector<int>{(int)dim});
157+ // }
158+ // }
159+
160+ // auto lowPrecisionCast =
161+ // [&](RewriterBase &rewriter, Location loc,
162+ // linalg::LinalgOp linalgop) -> FailureOr<linalg::LinalgOp> {
163+ // auto legalizedResult = matmulDtypeLegalize(
164+ // rewriter, linalgop.getOperation(), !hasFillOp, true);
165+ // if (legalizedResult->castOp && legalizedResult->linalgOp) {
166+ // auto linalgOp = legalizedResult->linalgOp;
167+ // rewriter.replaceOp(linalgop,
168+ // linalgOp->getResult(linalgOp->getNumResults() -
169+ // 1));
170+ // return dyn_cast<linalg::LinalgOp>(linalgOp);
171+ // }
172+ // return failure();
173+ // };
174+ // option.innermostFullResultCallBacks.push_back(lowPrecisionCast);
175+
176+ // if (hasFillOp) {
177+ // auto removeReduncantFill =
178+ // [&](RewriterBase &rewriter, Location loc,
179+ // const linalg::ForallReductionTilingResult &result)
180+ // -> FailureOr<linalg::LinalgOp> {
181+ // auto initValue = result.initialValues;
182+ // if (initValue.size() == 1 &&
183+ // isa<linalg::FillOp>(initValue[0].getDefiningOp())) {
184+ // rewriter.replaceOp(initValue[0].getDefiningOp(),
185+ // dyn_cast<DestinationStyleOpInterface>(
186+ // initValue[0].getDefiningOp())
187+ // .getDpsInits()[0]);
188+ // }
189+ // return dyn_cast<linalg::LinalgOp>(result.parallelTiledOp);
190+ // };
191+ // option.finalReduceCallBacks.push_back(removeReduncantFill);
192+ // }
193+ // return generateOuterLoop(rewriter, linalgOp, option);
194+ // }
195+
196+ LogicalResult matchAndRewrite (linalgx::ScaledDotProductAttentionOp sdpaOp,
53197 PatternRewriter &rewriter) const override {
54- if (!llvm::isa<linalgx::ScaledDotProductAttentionOp>(linalgOp))
55- return failure ();
56- if (linalgOp.hasPureBufferSemantics ())
57- return failure ();
198+ auto decomposableOp =
199+ dyn_cast<mlir::linalg::AggregatedOpInterface>(sdpaOp.getOperation ());
200+ FailureOr<SmallVector<Value>> maybeNewResults =
201+ decomposableOp.decomposeOperation (rewriter);
202+ rewriter.replaceOp (decomposableOp, *maybeNewResults);
203+ return success ();
58204 }
59205};
60206
@@ -65,19 +211,7 @@ struct FlashAttentionConversion
65211 auto &ctx = getContext ();
66212 IRRewriter rewriter (&ctx);
67213 RewritePatternSet patterns (&ctx);
68-
69214 patterns.add <MHAToFlashAttention>(patterns.getContext ());
70- // linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
71- // linalg::ControlDropUnitDims options;
72- // options.rankReductionStrategy =
73- // linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
74- // linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
75- // tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
76-
77- // for (auto *dialect : ctx.getLoadedDialects())
78- // dialect->getCanonicalizationPatterns(patterns);
79- // for (RegisteredOperationName op : ctx.getRegisteredOperations())
80- // op.getCanonicalizationPatterns(patterns, &ctx);
81215 if (failed (applyPatternsAndFoldGreedily (getOperation (),
82216 std::move (patterns)))) {
83217 return signalPassFailure ();
0 commit comments