Skip to content

Change rock::invertTransform to use FailureOr<T>#2206

Merged
Mr-Anyone merged 22 commits intoROCm:developfrom
Mr-Anyone:pr-1999
Jan 23, 2026
Merged

Change rock::invertTransform to use FailureOr<T>#2206
Mr-Anyone merged 22 commits intoROCm:developfrom
Mr-Anyone:pr-1999

Conversation

@Mr-Anyone
Copy link
Copy Markdown
Contributor

Motivation

Change the signature of rock::invertTransform so that the user knows it can fail.

Technical Details

Refactor and change code to use rock::invertTransform.

Resolves https://github.com/ROCm/rocMLIR-internal/issues/1999

Test Plan

Build and check CI.

Test Result

Build locally.

Submission Checklist

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request changes the signature of rock::invertTransforms to return FailureOr<ArrayAttr> instead of ArrayAttr, making it explicit that the function can fail. The change updates the function declaration, implementation, and all call sites throughout the codebase.

Changes:

  • Modified invertTransforms to return FailureOr<ArrayAttr> instead of ArrayAttr or nullptr
  • Updated failure handling to return LogicalResult::failure() instead of nullptr
  • Updated all call sites to check for success/failure and extract values using .value()

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 14 comments.

Show a summary per file
File Description
transformMapUtils.h Updated function signature to return FailureOr<ArrayAttr>
transformMapUtils.cpp Changed return type and failure handling in implementation
ThreadwiseGemmLowering.cpp Updated call site with proper failure checking
ShuffleGemmForReductions.cpp Updated call sites with mixed error handling patterns
RemoveOutputAlloc.cpp Updated call site with proper failure checking
GridwiseGemmToBlockwise.cpp Updated multiple call sites, some with unchecked .value() calls
BlockwiseLoadTileToThreadwise.cpp Updated call sites with unchecked .value() calls
BlockwiseGemmToThreadwise.cpp Updated call site with unchecked .value() call

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread mlir/lib/Dialect/Rock/Transforms/ShuffleGemmForReductions.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp Outdated
Comment on lines +1030 to +1031
ArrayAttr inputThreadSubTile2dViewInv =
invertTransforms(rewriter, loc, inputThreadSubTile2dView);
invertTransforms(rewriter, loc, inputThreadSubTile2dView).value();
Copy link

Copilot AI Jan 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling .value() on the result of invertTransforms without checking if the operation succeeded. If invertTransforms fails, this will cause undefined behavior. Add a failure check before calling .value() and return failure if the inversion fails.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to do something similar to what you did in ThreadwiseGemmLowering.cpp by calling something like succeeded on the value from invertTransforms.

TransformMapAttr invertedTrMap = invertTransformMap(b, trMap, loc);
if (!invertedTrMap)
return nullptr;
return LogicalResult::failure();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just return failure() here. No need for the LogicalResult part.

Comment on lines +1030 to +1031
ArrayAttr inputThreadSubTile2dViewInv =
invertTransforms(rewriter, loc, inputThreadSubTile2dView);
invertTransforms(rewriter, loc, inputThreadSubTile2dView).value();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to do something similar to what you did in ThreadwiseGemmLowering.cpp by calling something like succeeded on the value from invertTransforms.

FailureOr<ArrayAttr> inBufferViewsTrAttr =
invertTransforms(b, loc, inBufferViewsTr.threadSubTile);
if (failed(inBufferViewsTrAttr)) {
return failure();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid this at all costs. If something fails, we should always inform the user about what failed. return failure(); will just give the user the generic message "Lowering failed." which does not help to understand what failed.

You can use something like:

return op.emitError("invertTransforms failed");

Same goes for the rest of places in this PR where we check if invertTransforms failed.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the question now is if it is ok to just return failure in storeGemmInputTile? It is in GridwiseGemmToBlockwise.cpp line 800.


FailureOr<ArrayAttr> invertedThreadSubTileViews =
invertTransforms(rewriter, loc, gemm1OutSubTileViewsTr.threadSubTile);
if (succeeded(invertedThreadSubTileViews)) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we fail here if invertTransforms failed?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in this case we can exit and emit diagnostics early.

invertTransforms(rewriter, rOp.getLoc(), additionalOutputViews);
for (Attribute trMap : invertedOutViews) {
if (failed(maybeInvertedOutViews)) {
return rOp.emitError("invertTransforms failed");
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To my understanding, this is where we want to attach an error message?

FailureOr<ArrayAttr> inBufferViewsTrAttr =
invertTransforms(b, loc, inBufferViewsTr.threadSubTile);
if (failed(inBufferViewsTrAttr)) {
return failure();
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the question now is if it is ok to just return failure in storeGemmInputTile? It is in GridwiseGemmToBlockwise.cpp line 800.


FailureOr<ArrayAttr> invertedThreadSubTileViews =
invertTransforms(rewriter, loc, gemm1OutSubTileViewsTr.threadSubTile);
if (succeeded(invertedThreadSubTileViews)) {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in this case we can exit and emit diagnostics early.

Copy link
Copy Markdown
Contributor

@justinrosner justinrosner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also want some LIT tests for this change if possible.

FailureOr<ArrayAttr> maybeInputThreadSubTile2dViewInv =
invertTransforms(rewriter, loc, inputThreadSubTile2dView);
assert(succeeded(maybeInputThreadSubTile2dViewInv) &&
"This must work for partial reduction");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a comment that includes something like transforms must be invertible would make the error message more useful to users.

FailureOr<ArrayAttr> maybeInBufferViewsTrAttr =
invertTransforms(b, loc, inBufferViewsTr.threadSubTile);
if (failed(maybeInBufferViewsTrAttr)) {
return op.emitError("invertTransforms failed");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be worth it to make the comments here, and everywhere else below, more explicit to what transforms failed to invert.

FailureOr<ArrayAttr> maybeInvertedViews =
invertTransforms(rewriter, rOp.getLoc(), views);
if (failed(maybeInvertedViews) || maybeInvertedViews.value().empty()) {
LLVM_DEBUG(llvm::dbgs() << "gemm to reduce view inversion failed.\n");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change this debug into a op.EmitError()?

Copy link
Copy Markdown
Contributor

@justinrosner justinrosner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nit, but otherwise looks good!

}

if (!gemmOutToLinalgMaps.empty()) {
if (!maybeGemmOutToLinalgMaps.value().empty()) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Maybe save maybeGemmOutToLinalgMaps.value() to a new variable since it's going to be used in multiple places.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

Copy link
Copy Markdown
Contributor

@pabloantoniom pabloantoniom left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just make sure to run clang-format

@Mr-Anyone Mr-Anyone merged commit 9f84329 into ROCm:develop Jan 23, 2026
7 of 14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants