[GlobalOptimization] Add PromoteContractionOutputsPass#23824
[GlobalOptimization] Add PromoteContractionOutputsPass#23824HanKuanChen wants to merge 3 commits intoiree-org:mainfrom
Conversation
This pass promotes contraction operation outputs from lower-precision floating-point types (f16/bf16) to higher-precision types (f32). This is the counterpart of DemoteContractionInputsPass, which handles input type demotion. The following options are also added: - --iree-global-opt-promote-contraction-outputs-type - --iree-global-opt-promote-contraction-outputs-operations Signed-off-by: Han-Kuan Chen <hankuan.chen@sifive.com>
9373a1d to
9eb3fa6
Compare
| llvm::cl::values(clEnumValN(PromoteOperation::All, "all", | ||
| "Promote all contraction ops."), | ||
| clEnumValN(PromoteOperation::Conv, "conv", | ||
| "Only promote convolution ops."), | ||
| clEnumValN(PromoteOperation::Matmul, "matmul", | ||
| "Only promote matmul ops."), | ||
| clEnumValN(PromoteOperation::None, "none", | ||
| "Promote no contraction ops.")), |
There was a problem hiding this comment.
Do we want/need a test for "all" and "none"?
There was a problem hiding this comment.
I don't think we need explicit tests for "all" and "none":
- "all" is just the combination of matmul + conv, already covered by existing tests
- "none" would be a no-op (no changes to the IR)
| Operation *op = linalgOp.getOperation(); | ||
| if (promoteMatmul && IREE::LinalgExt::isPureMatmul(op)) { | ||
| replaceOpOutputs(static_cast<linalg::MatmulOp *>(nullptr)); | ||
| } else if (promoteMatmul && isa<linalg::MatvecOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::MatvecOp *>(nullptr)); | ||
| } else if (promoteMatmul && isa<linalg::VecmatOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::VecmatOp *>(nullptr)); | ||
| } else if (promoteMatmul && IREE::LinalgExt::isPureBatchMatmul(op)) { | ||
| replaceOpOutputs(static_cast<linalg::BatchMatmulOp *>(nullptr)); | ||
| } else if (promoteMatmul && isa<linalg::BatchMatvecOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::BatchMatvecOp *>(nullptr)); | ||
| } else if (promoteMatmul && isa<linalg::BatchVecmatOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::BatchVecmatOp *>(nullptr)); | ||
| } else if (promoteMatmul && isa<linalg::MatmulTransposeAOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::MatmulTransposeAOp *>(nullptr)); | ||
| } else if (promoteMatmul && isa<linalg::MatmulTransposeBOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::MatmulTransposeBOp *>(nullptr)); | ||
| } else if (promoteMatmul && isa<linalg::BatchMatmulTransposeAOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::BatchMatmulTransposeAOp *>(nullptr)); | ||
| } else if (promoteMatmul && isa<linalg::BatchMatmulTransposeBOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::BatchMatmulTransposeBOp *>(nullptr)); | ||
| } else if (promoteConv && isa<linalg::Conv2DOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::Conv2DOp *>(nullptr)); | ||
| } else if (promoteConv && isa<linalg::Conv2DNchwFchwOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::Conv2DNchwFchwOp *>(nullptr)); | ||
| } else if (promoteConv && isa<linalg::Conv2DNhwcHwcfOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::Conv2DNhwcHwcfOp *>(nullptr)); | ||
| } else if (promoteConv && isa<linalg::Conv2DNhwcFhwcOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::Conv2DNhwcFhwcOp *>(nullptr)); | ||
| } else if (promoteConv && isa<linalg::Conv2DNgchwFgchwOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::Conv2DNgchwFgchwOp *>(nullptr)); | ||
| } else if (promoteConv && isa<linalg::Conv2DNgchwGfchwOp>(op)) { | ||
| replaceOpOutputs(static_cast<linalg::Conv2DNgchwGfchwOp *>(nullptr)); | ||
| } else { | ||
| return failure(); | ||
| } | ||
|
|
There was a problem hiding this comment.
Non-blocking: this might not be feasible, but it would be nice if we had a way to either mutate the operation or create a new operation with promoted output type through the linalg::LinalgOp interface so we don't need to list all types. This should be possible for the demotion pass but might be problematic here because we'd need to change the result type.
There was a problem hiding this comment.
Thanks for the suggestion! I agree that a unified interface would be cleaner.
I'll keep the current approach for now since it's non-blocking, but this is good to keep in mind for future improvements.
| @@ -0,0 +1,349 @@ | |||
| // RUN: iree-opt --split-input-file -iree-global-opt-promote-contraction-outputs="type=bf16 operation=matmul" %s | FileCheck %s --check-prefix=F16-MATMUL | |||
There was a problem hiding this comment.
Also, if this was done through the interface we probably wouldn't need as many tests.
Signed-off-by: Han-Kuan Chen <hankuan.chen@sifive.com>
|
ping |
| "Demote no contraction ops.")), | ||
| llvm::cl::init(DemoteOperation::None)); | ||
|
|
||
| static llvm::cl::opt<PromoteType> clPromoteContractionOutputsType( |
There was a problem hiding this comment.
Just FYI, we treat these kind of flags as "developer" flags, and these are not expected to be supported at all. Also they only work with command line interface and do not work with the library interface. So if you are relying on it for a "deployment" use case, then that is inherently unstable.
We obviously wont pull the rug underneath from you, but just letting you know the general expectation of these kind of flags. I have no issues with actually having these flags.
There was a problem hiding this comment.
Thanks for the heads-up. I appreciate the clarification about these being developer flags.
We may use this in deployment. What would be the recommended stable approach?
Signed-off-by: Han-Kuan Chen <hankuan.chen@sifive.com>
|
ping |
This pass promotes contraction operation outputs from lower-precision floating-point types (f16/bf16) to higher-precision types (f32).
This is the counterpart of DemoteContractionInputsPass, which handles input type demotion.
The following options are also added: