Optimize aten::min/max.dim with TopK op #2780
danielhumanmod wants to merge 8 commits into microsoft:main
Conversation
@microsoft-github-policy-service agree
Codecov Report ❌
Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #2780      +/-   ##
==========================================
+ Coverage   70.46%   70.74%   +0.27%
==========================================
  Files         228      230       +2
  Lines       27258    27349      +91
  Branches     2761     2744      -17
==========================================
+ Hits        19208    19348     +140
+ Misses       7100     7067      -33
+ Partials      950      934      -16
```

☔ View full report in Codecov by Sentry.
Thanks so much for the review! That is a great point; I took some time to dig into the ONNX Runtime implementations to see how they handle this.
So to the best of my knowledge, TopK might bring more instruction overhead but with less I/O. I would appreciate your thoughts here: which approach aligns better with the community's needs? I am flexible to pivot to other tasks if we want to keep the original implementation.
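To make the tradeoff concrete, here is a quick PyTorch check (illustrative, not from the PR) of the equivalence that motivates the fusion; the min case is analogous with largest=False:

```python
import torch

x = torch.randn(4, 8)

# Current export path: two separate reductions over the same input.
values, indices = torch.max(x, dim=1)

# TopK with k=1 yields both outputs from a single scan; squeeze removes
# the kept dimension so the results line up with torch.max(dim=...).
top_values, top_indices = torch.topk(x, k=1, dim=1)

assert torch.equal(values, top_values.squeeze(1))
assert torch.equal(indices, top_indices.squeeze(1))
```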
I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is not used at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules? This way we can do the fusion only when both outputs are used (if not, the second output will be removed by the dead-code-elimination pass).
Yeah, that's a good point. It makes more sense to handle this in the rewriter/optimizer. I will take a look at the rules and follow up. Thanks for the feedback!
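For context, here is a minimal sketch of what such a function-based rule could look like, assuming onnxscript's pattern.RewriteRule API with separate pattern and replacement functions; the names, attribute handling, and axis wiring are simplified for illustration and are not the PR's actual code:

```python
from onnxscript.rewriter import pattern

def _reduce_argmax_pattern(op, x, axes):
    # Match the node pair torch.max(dim=...) currently exports
    # (opset 18+, where ReduceMax takes `axes` as an input).
    values = op.ReduceMax(x, axes, keepdims=1)
    indices = op.ArgMax(x, keepdims=1)
    return values, indices

def _topk_replacement(op, x, axes):
    # Produce both outputs from a single TopK scan with k = 1.
    k = op.Constant(value_ints=[1])
    values, indices = op.TopK(x, k, largest=1, sorted=1, _outputs=2)
    return values, indices

rule = pattern.RewriteRule(_reduce_argmax_pattern, _topk_replacement)
```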
Hey @justinchuby, I've added a new rewrite rule to optimize this case based on our previous discussion. Whenever you have a moment, I'd appreciate your thoughts on it. Thanks!
Pull request overview
Adds a new ONNXScript rewriter rule to fuse Reduce{Max,Min} + Arg{Max,Min} patterns into a single TopK (plus an optional Squeeze), aiming to improve performance for torch.min/max(dim=...)-style graphs; a before/after sketch of the node patterns follows the change list.
Changes:
- Introduces FuseReduce{Max,Min}Arg{Max,Min}ToTopK rewrite rules and a RewriteRuleSet.
- Adds extensive unit tests covering success and failure conditions across opsets 13 and 18.
- Validates numerical equivalence and serialized-model correctness for rewritten graphs.
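A hedged sketch of the before/after node patterns described above, for the max case (illustrative names built with onnx.helper, not code from the PR):

```python
import onnx

# Before: two independent reductions over the same input X.
reduce_max = onnx.helper.make_node(
    "ReduceMax", ["X", "axes"], ["values"], keepdims=1
)
arg_max = onnx.helper.make_node(
    "ArgMax", ["X"], ["indices"], axis=1, keepdims=1
)

# After: a single TopK provides both outputs (largest=0 gives the min
# variant); the optional Squeeze restores keepdims=0 output shapes.
top_k = onnx.helper.make_node(
    "TopK", ["X", "k"], ["values", "indices"], axis=1, largest=1, sorted=1
)
squeeze = onnx.helper.make_node(
    "Squeeze", ["values", "axes"], ["values_squeezed"]
)
```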
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py | Implements the Reduce+Arg → TopK fusion rules for both the max and min cases. |
| onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py | Adds unit tests for the new fusion rules, including opset and attribute/input variants. |
```python
# Step 3: Get axes from Reduce operation
# In opset 18+, axes is an input; in opset 13-17, it's an attribute
```
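The difference that comment refers to can be shown directly with onnx.helper (an illustration, not the PR's code):

```python
import onnx

# Opset 13-17: `axes` is an attribute baked into the node.
reduce_opset13 = onnx.helper.make_node(
    "ReduceMax", ["X"], ["Y"], axes=[1], keepdims=1
)

# Opset 18+: `axes` is an optional second input, so it can be a tensor
# produced at runtime.
reduce_opset18 = onnx.helper.make_node(
    "ReduceMax", ["X", "axes"], ["Y"], keepdims=1
)
```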
I wonder if we would be interested in only supporting opset 18+ here to reduce the complexity? (We have the version converter.) It's just a matter of whether we expect the rule to be applied standalone or not, I guess.
That makes sense to remove; I see this rule should mostly be used in the pipeline. Thanks for the suggestion!
Only opset 18+ is fine
Sorry @danielhumanmod, can you add a NOTE/comment somewhere that says the rule is only for opset 18+? Since it is no longer in the default rewrite rules, other users could apply it standalone.
```python
# Step 7: Normalize axes if rank is known (handle negative indices)
input_x = reduce_node.inputs[0]
rank = len(input_x.shape) if input_x.shape is not None else None
```
I wonder if symbolic shapes could work in this case? @justinchuby
Skipping when the shape is None means this does not support dynamic shapes at the moment. But symbolic inference should be able to handle the equality check.
Ohh, actually that is a very good catch. I will use shape.rank() instead to ensure support for both static and symbolic shapes, thanks a bunch!
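A minimal, self-contained sketch of the normalization step under discussion, assuming the rank comes from input_x.shape.rank() as mentioned above (the helper name is hypothetical):

```python
def normalize_axes(axes: list[int], rank: int | None) -> list[int] | None:
    """Map negative axes to their non-negative form once the rank is known.

    `rank` would come from input_x.shape.rank(), which needs only the number
    of dimensions and therefore works for symbolic shapes as well.
    """
    if rank is None:
        return None  # Fully dynamic rank: skip the rewrite.
    return [axis + rank if axis < 0 else axis for axis in axes]

# Example: with rank 3, axis -1 refers to the last dimension.
assert normalize_axes([-1, 0], rank=3) == [2, 0]
```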
You will have to enable it here:
I don't think we want to enable this by default; it is unclear whether this is generally more performant. @danielhumanmod you may simply expose the rule in https://github.com/microsoft/onnxscript/blob/main/onnxscript/rewriter/rules/common/__init__.py
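Standalone usage might then look like the following; the exported rule-set name here is assumed for illustration and should be checked against the module's __init__.py:

```python
import onnx
from onnxscript import rewriter

# NOTE: hypothetical import name, for illustration only.
from onnxscript.rewriter.rules.common import fuse_reduce_arg_to_topk_rules

model = onnx.load("model.onnx")
rewritten = rewriter.rewrite(model, pattern_rewrite_rules=fuse_reduce_arg_to_topk_rules)
onnx.save(rewritten, "model_fused.onnx")
```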
Hey team, I resolved all the pending comments. I'd appreciate it if you could take another look when you have time, thanks! cc @justinchuby @titaiwangms
titaiwangms left a comment:
Thank you. An unblocking comment.
Fix pytorch/pytorch#76344
Context
As mentioned in the issue, torch.max(dim=...) can be optimized with TopK to replace the current ReduceMax and ArgMax implementation. This optimization reduces redundant input scans and avoids potential performance overhead in certain execution providers (e.g., the ONNX Runtime CUDA EP, microsoft/onnxruntime#11348). In addition, since torch.min(dim=...) has a similar pattern to max, I also applied this optimization to it.
Verification
Successfully passed existing OpInfo consistency tests: