Skip to content

Optimize aten::min/max.dim with TopK op#2780

Open
danielhumanmod wants to merge 8 commits intomicrosoft:mainfrom
danielhumanmod:optimize-max-dim
Open

Optimize aten::min/max.dim with TopK op#2780
danielhumanmod wants to merge 8 commits intomicrosoft:mainfrom
danielhumanmod:optimize-max-dim

Conversation

@danielhumanmod
Copy link

Fix pytorch/pytorch#76344

Context

As mentioned in the issue, torch.max(dim=...) can be optimized with TopK to replace the current ReduceMax and ArgMax implementation. This optimization reduces redundant input scans and avoids potential performance overhead in certain execution providers (e.g., ONNX Runtime CUDA EP microsoft/onnxruntime#11348).

In additional, given the torch.min(dim=...) has the similar pattern with max, I also apply this optimization to it.

Verification

Successfully passed existing OpInfo consistency tests:

  • pytest tests/function_libs/torch_lib/ops_test.py
  • pytest tests/function_libs/torch_lib/e2e_ops_tests.py

@danielhumanmod
Copy link
Author

@danielhumanmod please read the following Contributor License Agreement(CLA). If you agree with the CLA, please reply with the following information.

@microsoft-github-policy-service agree [company="{your company}"]

Options:

  • (default - no company specified) I have sole ownership of intellectual property rights to my Submissions and I am not making Submissions in the course of work for my employer.
@microsoft-github-policy-service agree
  • (when company given) I am making Submissions in the course of work for my employer (or my employer has intellectual property rights in my Submissions by contract or applicable law). I have permission from my employer to make Submissions and enter into this Agreement on behalf of my employer. By signing below, the defined term “You” includes me and my employer.
@microsoft-github-policy-service agree company="Microsoft"

Contributor License Agreement

@microsoft-github-policy-service agree

@codecov
Copy link

codecov bot commented Jan 25, 2026

Codecov Report

❌ Patch coverage is 96.59574% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.74%. Comparing base (e06dd92) to head (264aed2).
⚠️ Report is 16 commits behind head on main.

Files with missing lines Patch % Lines
.../rewriter/rules/common/_fuse_reduce_arg_to_topk.py 92.59% 4 Missing and 2 partials ⚠️
...iter/rules/common/_fuse_reduce_arg_to_topk_test.py 98.70% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2780      +/-   ##
==========================================
+ Coverage   70.46%   70.74%   +0.27%     
==========================================
  Files         228      230       +2     
  Lines       27258    27349      +91     
  Branches     2761     2744      -17     
==========================================
+ Hits        19208    19348     +140     
+ Misses       7100     7067      -33     
+ Partials      950      934      -16     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@justinchuby justinchuby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for creating the PR. Reading it again it seems like topk is more general than ReduceMax and ArgMax. From a node count perspective this may be fewer nodes, but I wonder if the original is easier to optimize with.

@github-project-automation github-project-automation bot moved this from Todo to In Progress in ONNX Script Review Board Jan 25, 2026
@danielhumanmod
Copy link
Author

Thanks for creating the PR. Reading it again it seems like topk is more general than ReduceMax and ArgMax. From a node count perspective this may be fewer nodes, but I wonder if the original is easier to optimize with.

Thanks so much for the review! That is a great point, I took some time to dig into the ONNX Runtime implementations to see how they handle this.

  1. From ONNX runtime perspective,

    1. CPU EP provide a fastline when k = 1, which performs a simple linear scan. So on CPU, it seems to behave identically to a fused max+argmax.
    2. CUDA EP will walk through the whole Bitonic/Radix sort process, which can involve more complex instructions. But the upside is that these operations happen primarily in shared memory.
  2. PyTorch Inductor (as an reference): it adopts a similar approach—splitting into reduce_max/arg_max in IR—but leaves it to the runtime (Scheduler) to fuse them. However, when I checked ONNX Runtime, it didn't seem to have an optimization rule to automatically fuse ReduceMax and ArgMax, which implies the split approach effectively incurs one more IO pass compared to TopK

So to the best of my knowledge, TopK might brings more instruction overhead but with less IO. I would appreciate your thoughts here—which approach aligns more with the community's needs? I am flexible to pivot to other tasks if we want to keep the original implementation.

@justinchuby
Copy link
Collaborator

I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is unused at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules ? This way we can do fusion only when the two outputs are used (if not the second output will be removed by the dead code elimination pass)

@danielhumanmod
Copy link
Author

I am not exactly sure what the actual usage of this operator looks like. Are the two outputs always used? One can imagine that if the second output is unused at all, computing it would be a waste of effort. I wonder if it would make sense for you to contribute a rewrite rule to https://github.com/microsoft/onnxscript/tree/main/onnxscript/rewriter/rules ? This way we can do fusion only when the two outputs are used (if not the second output will be removed by the dead code elimination pass)

Yeah, that's a good point. It makes more sense to handle this in the rewriter/optimizer. I will take a look at the rules and follow up. Thanks for the feedback!

@danielhumanmod
Copy link
Author

Hey @justinchuby ,I’ve added a new rewrite rule to optimize this case based on our previous discussion. Whenever you have a moment, I’d appreciate your thoughts on it. Thanks!

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new ONNXScript rewriter rule to fuse Reduce{Max,Min} + Arg{Max,Min} patterns into a single TopK (plus optional Squeeze), aiming to improve performance for torch.min/max(dim=...)-style graphs.

Changes:

  • Introduces FuseReduce{Max,Min}Arg{Max,Min}ToTopK rewrite rules and a RewriteRuleSet.
  • Adds extensive unit tests covering success and failure conditions across opset 13 and 18.
  • Validates numerical equivalence and serialized-model correctness for rewritten graphs.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

File Description
onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py Implements the Reduce+Arg → TopK fusion rules for both max and min cases.
onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py Adds unit tests for the new fusion rules, including opset and attribute/input variants.

)

# Step 3: Get axes from Reduce operation
# In opset 18+, axes is an input; in opset 13-17, it's an attribute
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we would be interested in only supporting opset 18+ here to reduce the complexity? (we have version converter) It's just the matter whether we see the rule will be applied standalone or not I guess?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense to remove, I see this rule should be mostly used in pipeline, thanks for the suggestion!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justinchuby What do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only opset 18+ is fine

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @danielhumanmod Can you add a NOTE/comment somewhere that says the rule is only for opset 18+. Since now it's not for default rewrite rules, it could be used standalone for other users.


# Step 7: Normalize axes if rank is known (handle negative indices)
input_x = reduce_node.inputs[0]
rank = len(input_x.shape) if input_x.shape is not None else None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if symbolic shape could work on this case? @justinchuby

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you elaborate?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Skipping none of shape means this does not support dynamic at the moment. But symbolic inference should be able to handle the eq

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohh actually it is a very good catch, I will use shape.rank() instead to ensure to support both static and symbolic shape, thanks a bunch!

@titaiwangms
Copy link
Contributor

You will have to enable it here:

_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (

@justinchuby
Copy link
Collaborator

justinchuby commented Feb 7, 2026

You will have to enable it here:

_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = (

I don’t think we want to enable this by default. It is unclear if this is generally more performant. @danielhumanmod you may simply expose the rule in https://github.com/microsoft/onnxscript/blob/main/onnxscript/rewriter/rules/common/__init__.py

@danielhumanmod
Copy link
Author

Hey team I solved all the pending comments, appreciate if you could take another look when you have time, thanks! cc @justinchuby @titaiwangms

Copy link
Contributor

@titaiwangms titaiwangms left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. an unblocking comment

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Status: In Progress

Development

Successfully merging this pull request may close these issues.

[ONNX] Use topk to export max(dim,keepdim) to onnx

3 participants