Remove redundant casts in LLVMIR #2202
Conversation
```mlir
%7 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16>
llvm.store %7, %6 : vector<4xf16>, !llvm.ptr<5>
%8 = llvm.getelementptr %2[12] : (!llvm.ptr<5>) -> !llvm.ptr<5>, f16
%9 = llvm.fptrunc %1 : vector<4xf32> to vector<4xf16>
```
Why are there so many repeated `llvm.fptrunc %1` ops? I don't understand this.
Wouldn't it be easier to do this earlier, for example by handling arith.extf, etc.?
The fptruncs appear to be the result of loop unrolling; they are all writing to the same buffer. I did try doing this earlier, but there are quite a few more difficulties with moving this pass to right after GridwiseGemmToBlockwise: the dominance analysis (used for safety) doesn't work well when the trunc/ext ops are in different regions, and having to rewrite the linalg.generic makes things more difficult as well.
Both approaches have their pros and cons. We can discuss this more in the team meeting, or elsewhere offline.
The main problem I see right now is that whichever approach we take will have to be completely rewritten for rocmlirTriton. The good news is that re-implementing this pass in rocmlirTriton should be almost trivial, because we are removing the complexity of linalg.generics.
Force-pushed from 2463025 to 350c3c7
Pull request overview
Adds a new Rock LLVMIR-dialect optimization pass to eliminate redundant fptrunc -> store -> load -> fpext round-trips by introducing/reusing a parallel wide buffer, and wires it into the backend pipeline. Updates the CPU attention MLIR generator to better match the post-pass GPU behavior, and adds regression/e2e tests plus pipeline expectation updates.
Changes:
- Introduce the `rock-remove-redundant-casts` LLVMIR pass and add it to the Rock backend pipeline.
- Update `rocmlir-gen` CPU attention kernel generation to promote narrow-float intermediates to `f32` and insert the necessary casts.
- Add LLVMIR-level and end-to-end tests; update pipeline test expectations.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| mlir/lib/Dialect/Rock/Transforms/RemoveRedundantCasts.cpp | New LLVMIR pass implementing redundant cast elimination via parallel wide buffers |
| mlir/lib/Dialect/Rock/Transforms/CMakeLists.txt | Build integration for the new transform source |
| mlir/include/mlir/Dialect/Rock/Passes.td | Pass definition for rock-remove-redundant-casts |
| mlir/include/mlir/Dialect/Rock/Passes.h | Generated pass declaration hook-up |
| mlir/lib/Dialect/Rock/Pipelines/Pipelines.cpp | Inserts the new pass into the backend LLVM func pipeline |
| mlir/tools/rocmlir-gen/rocmlir-gen.cpp | Adjusts CPU attention kernel typing/casting to align with new optimization behavior |
| mlir/test/rocmlir-driver/pipelines.mlir | Updates expected pipeline printout to include the new pass |
| mlir/test/Dialect/LLVMIR/remove-redundant-casts.mlir | New LLVMIR-level regression tests for safe/unsafe patterns + cleanup |
| mlir/test/fusion/pr-e2e/rock-attention-redundant-casts.mlir | New e2e attention test asserting absence of llvm.fpext after the pass |
| mlir/test/fusion/pr-e2e/mixr-remove-redundant-casts.mlir | New e2e MIGraphX-style test asserting absence of llvm.fpext after the pass |
Force-pushed from 4e2d2d8 to 197bbc2
```cpp
// Cast result tensor back to the original output element type if we
// promoted intermediate computation to f32 (for narrow float types).
Type resultElemType =
    cast<ShapedType>(resultTensor.getType()).getElementType();
Type outputElemType = cast<ShapedType>(outputType).getElementType();
if (resultElemType != outputElemType) {
```
We are not touching the second gemm's output type, so it should match, I think. Why is this change necessary?
When the promotion fires (`!hasAttnScale && !hasAttnBias`), `resultOutElementType` is `f32`, so the second GEMM outputs `f32`. But the output buffer is still in the original type (`f16`/`bf16`).
```cpp
}
sumAbsDiff += static_cast<double>(absDiff);
// Update maxRelDiff only if cpuVal != 0
// Compute relDiff only for elements with meaningful absolute error.
```
If the CPU is also eliminating the trunc -> fpext pattern by changing the first gemm's output type, then why is this necessary?
Because our RemoveRedundantCasts pass is conservative by nature, there will be some instances where it doesn't trigger for the GPU case.
```cpp
bool absDiffExplicit = absDiffThreshold.getNumOccurrences() > 0;
auto absDiffGateVal =
    arith::ConstantIntOp::create(b, loc, boolType, absDiffExplicit);
```
I think we should hold off on this change in this PR if it is not necessary.
I was seeing lots of failures in the nightly CI without these changes.
Force-pushed from a8e3e1b to 3d3f8f7
Nightly CI is passing here: https://ml-ci-internal.amd.com/job/MLIR/job/mlir/job/PR-2202/24/pipeline-overview/. I will kick off a PR CI to make sure everything is working as expected and then merge.
* Initial truncf finding
* Add in logic so that we are only finding truncfs with direct stores
* Minor comment and debug message fixes
* Add detection for extf ops
* Partial verification of store/load chains
* Initial attempt at LLVMIR level transformation
* Add E2E test
* More LIT tests
* Add newline
* Clang-format
* Remove some extra lines
* Conservative checks
* Remove dynamic check
* Add additional attention edge case
* Update rocmlir-gen intermediate results
* Attend to review comments
* Small fix
* Try to improve on relDiff calculations
* Additional LIT tests
* Fixup some comments
* Attend to more review comments
* Clang-format
* Attend to more review comments
* Add sub-byte check
* Fix test ranges
Motivation
When processing mixed-precision computations (e.g., attention kernels with f32 intermediate values stored as f16), the generated IR often contains redundant precision conversion patterns:
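A minimal sketch of the pattern, with hypothetical SSA values and types modeled on the diff snippets above:

```mlir
// A wide f32 value is truncated to f16 just to store it...
%t = llvm.fptrunc %wide : vector<4xf32> to vector<4xf16>
llvm.store %t, %buf : vector<4xf16>, !llvm.ptr<5>
// ...then loaded back and immediately re-extended to f32.
%n  = llvm.load %buf : !llvm.ptr<5> -> vector<4xf16>
%w2 = llvm.fpext %n : vector<4xf16> to vector<4xf32>
```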
This pattern causes unnecessary precision loss compared to just keeping the original wide value. This pass eliminates these redundant casts by redirecting loads to read from a parallel wide buffer when possible.
This implements: https://github.com/ROCm/rocMLIR-internal/issues/1932
Technical Details
This PR introduces the `RemoveRedundantCasts` pass, which operates at the LLVMIR dialect level to optimize fptrunc -> store -> load -> fpext patterns.

General Algorithm:
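The steps are not spelled out here, but based on the pass description above, a hedged before/after sketch of the rewrite (buffer names and values hypothetical) would be: keep the narrow store for any other users, add a parallel buffer in the wide type, and redirect dominated load+fpext pairs to read the wide buffer directly:

```mlir
// Before: narrow round-trip through the f16 buffer loses precision.
%t  = llvm.fptrunc %acc : vector<4xf32> to vector<4xf16>
llvm.store %t, %narrow : vector<4xf16>, !llvm.ptr<5>
%ld = llvm.load %narrow : !llvm.ptr<5> -> vector<4xf16>
%e  = llvm.fpext %ld : vector<4xf16> to vector<4xf32>

// After: the dominated load reads a parallel wide buffer instead.
%t2 = llvm.fptrunc %acc : vector<4xf32> to vector<4xf16>
llvm.store %t2, %narrow : vector<4xf16>, !llvm.ptr<5>  // kept for other users
llvm.store %acc, %wide : vector<4xf32>, !llvm.ptr<5>
%e2 = llvm.load %wide : !llvm.ptr<5> -> vector<4xf32>  // fpext eliminated
```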
Test Plan
Test Result
Submission Checklist