Skip to content

Remove redundant casts in LLVMIR#2202

Merged
justinrosner merged 26 commits intodevelopfrom
1932-remove-casts
Apr 2, 2026
Merged

Remove redundant casts in LLVMIR#2202
justinrosner merged 26 commits intodevelopfrom
1932-remove-casts

Conversation

@justinrosner
Copy link
Copy Markdown
Contributor

@justinrosner justinrosner commented Jan 8, 2026

Motivation

When processing mixed-precision computations (e.g., attention kernels with f32 intermediate values stored as f16), the generated IR often contains redundant precision conversion patterns:

%wide = ...                           ; f32 computation result
%narrow = llvm.fptrunc %wide : f32 to f16
llvm.store %narrow, %narrow_buf       ; store truncated value
...
%loaded = llvm.load %narrow_buf       ; load truncated value  
%extended = llvm.fpext %loaded : f16 to f32  ; extend back to f32

This pattern causes unnecessary precision loss compared to just keeping the original wide value. This pass eliminates these redundant casts by redirectoing loads to read from a parallel wide buffer when possible.

This implements: https://github.com/ROCm/rocMLIR-internal/issues/1932

Technical Details

This PR introduces the RemoveRedundantCasts pass that operates at the LLVMIR dialect level to optimize fptrunc -> store -> load -> fpext patterns.

General Algorithm:

  1. Find all fptrunc -> store patterns in the function. For each pattern, record whether there's already a parallel store of the wide value to a separate buffer.
  2. Find all load -> fpext patterns where the load is from a buffer that has fptrunc stores.
  3. Verify safety for each load+fpext pattern:
    • All stores to the narrow buffer must be from tracked fptrunc patterns (i.e., no untracked stores that could write different values)
    • All tracked stores must dominate the load
    • The narrow buffer must be an alloca
  4. For safe patterns, create a wide buffer and the corresponding stores if they don't exist. If a parallel store already exists, reuse it:
    • Create a wide alloca right after the narrow alloca
    • For each fptrunc store, insert a store of the wide value to the wide buffer (right after the narrow store, using the same indices)
  5. Apply the transformation:
    • Redirect the load to read from the wide buffer instead
    • Replace uses of the fpext result with the wide load result
    • Delete the fpext (and the old load/GEP if unused)
  6. Clean up unused narrow buffer operations:
    • If the narrow buffer has no remaining uses, erase the fptrunc stores
      • These can only be erased if they are not used by any other operations
    • Erase the narrow alloca if it has no remaining uses

Test Plan

Test Result

  • Nightly CI

Submission Checklist

@justinrosner justinrosner mentioned this pull request Jan 8, 2026
3 tasks
%7 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16>
llvm.store %7, %6 : vector<4xf16>, !llvm.ptr<5>
%8 = llvm.getelementptr %2[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16
%9 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are there so many repeated llvm.fptrunc %1? I don't understand this.
Wouldn't it be easier to do this earlier? for example handling arith.extf, etc?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fptrunc's seem to be the result of loop unrolling. They are all writing to the same buffer. I was doing this earlier and there are quite a few more difficulties with moving this pass somewhere right after GridwiseGemmToBlockwise. The dominance analysis (used for safety) becomes tricky because it doesn't work well when the trunc/ext ops are in different regions, and also having to rewrite the linalg generic makes things more difficult as well.

Both approaches have their pros and cons. We can discuss this more in the team meeting, or elsewhere offline.

Copy link
Copy Markdown
Contributor Author

@justinrosner justinrosner Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main problem I see right now is that either approach that we take is going to have to be completely rewritten for rocmlirTriton. The good news is that re-implementing this pass in rocmlirTriton should be almost trivial because we are removing the complexity of linalg.generics.

Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp Outdated
@justinrosner justinrosner changed the title [WIP] Remove redundant casts in LLVMIR Remove redundant casts in LLVMIR Mar 11, 2026
@justinrosner justinrosner marked this pull request as ready for review March 11, 2026 20:46
@justinrosner justinrosner requested a review from causten as a code owner March 11, 2026 20:46
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new Rock LLVMIR-dialect optimization pass to eliminate redundant fptrunc -> store -> load -> fpext round-trips by introducing/reusing a parallel wide buffer, and wires it into the backend pipeline. Updates the CPU attention MLIR generator to better match the post-pass GPU behavior, and adds regression/e2e tests plus pipeline expectation updates.

Changes:

  • Introduce rock-remove-redundant-casts LLVMIR pass and add it to the Rock backend pipeline.
  • Update rocmlir-gen CPU attention kernel generation to promote narrow-float intermediates to f32 and insert necessary casts.
  • Add LLVMIR-level and end-to-end tests; update pipeline test expectations.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp New LLVMIR pass implementing redundant cast elimination via parallel wide buffers
mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt Build integration for the new transform source
mlir/include/mlir/Dialect/Rock/Passes.td Pass definition for rock-remove-redundant-casts
mlir/include/mlir/Dialect/Rock/Passes.h Generated pass declaration hook-up
mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp Inserts the new pass into the backend LLVM func pipeline
mlir/tools/rocmlir-gen/rocmlir-gen.cpp Adjusts CPU attention kernel typing/casting to align with new optimization behavior
mlir/test/rocmlir-driver/pipelines.mlir Updates expected pipeline printout to include the new pass
mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir New LLVMIR-level regression tests for safe/unsafe patterns + cleanup
mlir/test/fusion/pr-e2e/rock-attention-redundant-casts.mlir New e2e attention test asserting absence of llvm.fpext after the pass
mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir New e2e MIGraphX-style test asserting absence of llvm.fpext after the pass

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp
Comment thread mlir/lib/ExecutionEngine/conv-validation-wrappers.cpp
Comment thread mlir/lib/ExecutionEngine/conv-validation-wrappers.cpp
Comment thread mlir/tools/rocmlir-gen/rocmlir-gen.cpp
Comment thread mlir/tools/rocmlir-gen/rocmlir-gen.cpp Outdated
Copy link
Copy Markdown
Member

@umangyadav umangyadav left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work !

Comment thread mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir
Comment thread mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir
Comment thread mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir Outdated
Comment thread mlir/test/fusion/pr-e2e/rock-attention-redundant-casts.mlir
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp Outdated
Comment thread mlir/tools/rocmlir-gen/rocmlir-gen.cpp Outdated
Comment on lines +4499 to +4504
// Cast result tensor back to the original output element type if we
// promoted intermediate computation to f32 (for narrow float types).
Type resultElemType =
cast<ShapedType>(resultTensor.getType()).getElementType();
Type outputElemType = cast<ShapedType>(outputType).getElementType();
if (resultElemType != outputElemType) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not touching second gemm output type so it should match i think. Why is this change necessary?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the promotion fires (!hasAttnScale && !hasAttnBias), resultOutElementType = f32, so the second GEMM outputs f32. But the output buffer is still in the original type (f16/bf16)

}
sumAbsDiff += static_cast<double>(absDiff);
// Update maxRelDiff only if cpuVal != 0
// Compute relDiff only for elements with meaningful absolute error.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if CPU is also eliminating trunc -> fpext pattern by changing first gemm output tyep then why this is necessary ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because our RemoveRedundantCasts pass is conservative by nature, there will be some instances where this doesn't trigger for the GPU case

Comment on lines +4751 to +4754
bool absDiffExplicit = absDiffThreshold.getNumOccurrences() > 0;
auto absDiffGateVal =
arith::ConstantIntOp::create(b, loc, boolType, absDiffExplicit);

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should hold off from this change in this PR if it is not necessary

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was seeing lots of failures in the nightly CI without these changes

@justinrosner justinrosner requested a review from umangyadav March 31, 2026 16:54
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp
@justinrosner justinrosner requested a review from umangyadav March 31, 2026 19:43
@justinrosner
Copy link
Copy Markdown
Contributor Author

Nightly CI is passing here: https://ml-ci-internal.amd.com/job/MLIR/job/mlir/job/PR-2202/24/pipeline-overview/, will kick off a PR CI to make sure everything is working as expected and then merge.

@justinrosner justinrosner merged commit 281add9 into develop Apr 2, 2026
8 of 15 checks passed
@justinrosner justinrosner deleted the 1932-remove-casts branch April 2, 2026 13:57
umangyadav pushed a commit that referenced this pull request Apr 9, 2026
* Initial truncf finding

* Add in logic so that we are only finding truncfs with direct stores

* Minor comment and debug message fixes

* Add detection for extf ops

* Partial verification of store/load chains

* Initial attempt at LLVMIR level transformation

* Add E2E test

* More LIT tests

* Add newline

* Clang-format

* Remove some extra lines

* Conservative checks

* Remove dynamic check

* Add additional attention edge case

* Update rocmlir-gen intermediate results

* Attend to review comments

* Small fix

* Try to improve on relDiff calculations

* Additional LIT tests

* Fixup some comments

* Attend to more review comments

* Clang-format

* Attend to more review comments

* Add sub-byte check

* Fix test ranges
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants