Skip to content

Fix element type mismatch in attention preSoftmax fusion#2211

Merged
justinrosner merged 6 commits intodevelopfrom
justinr-linalg-fix
Jan 29, 2026
Merged

Fix element type mismatch in attention preSoftmax fusion#2211
justinrosner merged 6 commits intodevelopfrom
justinr-linalg-fix

Conversation

@justinrosner
Copy link
Copy Markdown
Contributor

@justinrosner justinrosner commented Jan 20, 2026

Motivation

This PR fixes a crash that MIGraphX was seeing when compiling an attention kernel with fusion: https://amd-hub.atlassian.net/browse/AIROCMLIR-438

Technical Details

When lowering gridwise_attention_accel ops with preSoftmax fusion, the gemm0 output buffer element type was unconditionally set to elemTypeV (the values input element type). This caused a type mismatch when the preSoftmax body's linalg.generic expected a different element type for it's gemm0 based input (e.g., when the linalg.generic was truncating/extending).

Test Plan

  • PR CI
  • Original kernel from MIGraphX is passing (and turned into a LIT test)

Test Result

  • PR CI

Submission Checklist

@justinrosner justinrosner requested a review from causten as a code owner January 20, 2026 13:44
Copilot AI review requested due to automatic review settings January 20, 2026 13:44
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes a crash in MIGraphX when compiling attention kernels with preSoftmax fusion. The issue occurred when lowering gridwise_attention_accel operations where the gemm0 output buffer element type was incorrectly set to the values input element type (elemTypeV), causing a type mismatch when the preSoftmax body's linalg.generic operation expected a different element type (e.g., when truncating or extending).

Changes:

  • Modified element type determination logic to walk the preSoftmax body and extract the correct type from the first linalg.generic operation's gemm0-based input
  • Added a comprehensive LIT test that reproduces the original MIGraphX failure scenario with type conversions in the preSoftmax fusion body

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Added logic to walk preSoftmax body and determine gemmOutElemType from the first generic's gemm0-based input, and fusionOutElemType from the last generic's output, fixing the element type mismatch
mlir/test/Dialect/Rock/gridwise-gemm-linalg-failure.mlir New test file verifying correct handling of attention operations with preSoftmax fusion that performs f16 to f32 extension, ensuring the lowering produces correct buffer types

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Outdated
Comment thread mlir/test/Dialect/Rock/gridwise-gemm-input-fusion-type-change.mlir
Copy link
Copy Markdown
Contributor

@pabloantoniom pabloantoniom left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I second both of Daniel's comments

Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp Outdated
Comment thread mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp
@justinrosner justinrosner merged commit ffbacc7 into develop Jan 29, 2026
9 of 16 checks passed
@justinrosner justinrosner deleted the justinr-linalg-fix branch January 29, 2026 17:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants