Change rock::invertTransform to use FailureOr<T>#2206
Change rock::invertTransform to use FailureOr<T>#2206Mr-Anyone merged 22 commits intoROCm:developfrom
rock::invertTransform to use FailureOr<T>#2206Conversation
There was a problem hiding this comment.
Pull request overview
This pull request changes the signature of rock::invertTransforms to return FailureOr<ArrayAttr> instead of ArrayAttr, making it explicit that the function can fail. The change updates the function declaration, implementation, and all call sites throughout the codebase.
Changes:
- Modified
invertTransformsto returnFailureOr<ArrayAttr>instead ofArrayAttrornullptr - Updated failure handling to return
LogicalResult::failure()instead ofnullptr - Updated all call sites to check for success/failure and extract values using
.value()
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 14 comments.
Show a summary per file
| File | Description |
|---|---|
| transformMapUtils.h | Updated function signature to return FailureOr<ArrayAttr> |
| transformMapUtils.cpp | Changed return type and failure handling in implementation |
| ThreadwiseGemmLowering.cpp | Updated call site with proper failure checking |
| ShuffleGemmForReductions.cpp | Updated call sites with mixed error handling patterns |
| RemoveOutputAlloc.cpp | Updated call site with proper failure checking |
| GridwiseGemmToBlockwise.cpp | Updated multiple call sites, some with unchecked .value() calls |
| BlockwiseLoadTileToThreadwise.cpp | Updated call sites with unchecked .value() calls |
| BlockwiseGemmToThreadwise.cpp | Updated call site with unchecked .value() call |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| ArrayAttr inputThreadSubTile2dViewInv = | ||
| invertTransforms(rewriter, loc, inputThreadSubTile2dView); | ||
| invertTransforms(rewriter, loc, inputThreadSubTile2dView).value(); |
There was a problem hiding this comment.
Calling .value() on the result of invertTransforms without checking if the operation succeeded. If invertTransforms fails, this will cause undefined behavior. Add a failure check before calling .value() and return failure if the inversion fails.
There was a problem hiding this comment.
We want to do something similar to what you did in ThreadwiseGemmLowering.cpp by calling something like succeeded on the value from invertTransforms.
| TransformMapAttr invertedTrMap = invertTransformMap(b, trMap, loc); | ||
| if (!invertedTrMap) | ||
| return nullptr; | ||
| return LogicalResult::failure(); |
There was a problem hiding this comment.
You can just return failure() here. No need for the LogicalResult part.
| ArrayAttr inputThreadSubTile2dViewInv = | ||
| invertTransforms(rewriter, loc, inputThreadSubTile2dView); | ||
| invertTransforms(rewriter, loc, inputThreadSubTile2dView).value(); |
There was a problem hiding this comment.
We want to do something similar to what you did in ThreadwiseGemmLowering.cpp by calling something like succeeded on the value from invertTransforms.
| FailureOr<ArrayAttr> inBufferViewsTrAttr = | ||
| invertTransforms(b, loc, inBufferViewsTr.threadSubTile); | ||
| if (failed(inBufferViewsTrAttr)) { | ||
| return failure(); |
There was a problem hiding this comment.
Avoid this at all costs. If something fails, we should always inform the user about what failed. return failure(); will just give the user the generic message "Lowering failed." which does not help to understand what failed.
You can use something like:
return op.emitError("invertTransforms failed");
Same goes for the rest of places in this PR where we check if invertTransforms failed.
There was a problem hiding this comment.
I guess the question now is if it is ok to just return failure in storeGemmInputTile? It is in GridwiseGemmToBlockwise.cpp line 800.
|
|
||
| FailureOr<ArrayAttr> invertedThreadSubTileViews = | ||
| invertTransforms(rewriter, loc, gemm1OutSubTileViewsTr.threadSubTile); | ||
| if (succeeded(invertedThreadSubTileViews)) { |
There was a problem hiding this comment.
Shouldn't we fail here if invertTransforms failed?
There was a problem hiding this comment.
I think in this case we can exit and emit diagnostics early.
| invertTransforms(rewriter, rOp.getLoc(), additionalOutputViews); | ||
| for (Attribute trMap : invertedOutViews) { | ||
| if (failed(maybeInvertedOutViews)) { | ||
| return rOp.emitError("invertTransforms failed"); |
There was a problem hiding this comment.
To my understanding, this is where we want to attach an error message?
| FailureOr<ArrayAttr> inBufferViewsTrAttr = | ||
| invertTransforms(b, loc, inBufferViewsTr.threadSubTile); | ||
| if (failed(inBufferViewsTrAttr)) { | ||
| return failure(); |
There was a problem hiding this comment.
I guess the question now is if it is ok to just return failure in storeGemmInputTile? It is in GridwiseGemmToBlockwise.cpp line 800.
|
|
||
| FailureOr<ArrayAttr> invertedThreadSubTileViews = | ||
| invertTransforms(rewriter, loc, gemm1OutSubTileViewsTr.threadSubTile); | ||
| if (succeeded(invertedThreadSubTileViews)) { |
There was a problem hiding this comment.
I think in this case we can exit and emit diagnostics early.
justinrosner
left a comment
There was a problem hiding this comment.
We also want some LIT tests for this change if possible.
| FailureOr<ArrayAttr> maybeInputThreadSubTile2dViewInv = | ||
| invertTransforms(rewriter, loc, inputThreadSubTile2dView); | ||
| assert(succeeded(maybeInputThreadSubTile2dViewInv) && | ||
| "This must work for partial reduction"); |
There was a problem hiding this comment.
I think a comment that includes something like transforms must be invertible would make the error message more useful to users.
| FailureOr<ArrayAttr> maybeInBufferViewsTrAttr = | ||
| invertTransforms(b, loc, inBufferViewsTr.threadSubTile); | ||
| if (failed(maybeInBufferViewsTrAttr)) { | ||
| return op.emitError("invertTransforms failed"); |
There was a problem hiding this comment.
It might be worth it to make the comments here, and everywhere else below, more explicit to what transforms failed to invert.
| FailureOr<ArrayAttr> maybeInvertedViews = | ||
| invertTransforms(rewriter, rOp.getLoc(), views); | ||
| if (failed(maybeInvertedViews) || maybeInvertedViews.value().empty()) { | ||
| LLVM_DEBUG(llvm::dbgs() << "gemm to reduce view inversion failed.\n"); |
There was a problem hiding this comment.
Change this debug into a op.EmitError()?
justinrosner
left a comment
There was a problem hiding this comment.
Small nit, but otherwise looks good!
| } | ||
|
|
||
| if (!gemmOutToLinalgMaps.empty()) { | ||
| if (!maybeGemmOutToLinalgMaps.value().empty()) { |
There was a problem hiding this comment.
nit: Maybe save maybeGemmOutToLinalgMaps.value() to a new variable since it's going to be used in multiple places.
pabloantoniom
left a comment
There was a problem hiding this comment.
LGTM, just make sure to run clang-format
Motivation
Change the signature of
rock::invertTransformso that the user knows it can fail.Technical Details
Refactor and change code to use
rock::invertTransform.Resolves https://github.com/ROCm/rocMLIR-internal/issues/1999
Test Plan
Build and check CI.
Test Result
Build locally.
Submission Checklist