From d15e6ac7af7658235dc5eab17552695c6f832690 Mon Sep 17 00:00:00 2001 From: Daniel Tu Date: Sat, 24 Jan 2026 22:22:31 -0800 Subject: [PATCH 1/9] optimize min/max dim with topk --- .../function_libs/torch_lib/ops/core.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 860b878edb..289ecd9ed2 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -6155,9 +6155,13 @@ def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, I result = self indices = op.Constant(value_int=0) else: - dims = op.Reshape(dim, op.Constant(value_ints=[-1])) - result = op.ReduceMax(self, dims, keepdims=keepdim) - indices = op.ArgMax(self, axis=dim, keepdims=keepdim) + values, indices = op.TopK(self, K=[1], axis=dim, largest=1, sorted=0) + if keepdim: + result = values + else: + squeeze_axe = op.Constant(value_ints=[dim]) + result = op.Squeeze(values, axes=squeeze_axe) + indices = op.Squeeze(indices, axes=squeeze_axe) return result, indices @@ -6242,10 +6246,13 @@ def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, T result = self indices = op.Constant(value_int=0) else: - dims = op.Reshape(dim, op.Constant(value_ints=[-1])) - result = op.ReduceMin(self, dims, keepdims=keepdim) - indices = op.ArgMin(self, axis=dim, keepdims=keepdim) - + values, indices = op.TopK(self, K=[1], axis=dim, largest=0, sorted=0) + if keepdim: + result = values + else: + squeeze_axe = op.Constant(value_ints=[dim]) + result = op.Squeeze(values, axes=squeeze_axe) + indices = op.Squeeze(indices, axes=squeeze_axe) return result, indices From 39baa3829c9726f897acc13cb0b2db24502618de Mon Sep 17 00:00:00 2001 From: Daniel Tu Date: Sun, 25 Jan 2026 10:19:37 -0800 Subject: [PATCH 2/9] tests --- .../function_libs/torch_lib/e2e_ops_tests.py | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py index 9ee12f3ac3..84a4067b3f 100644 --- a/tests/function_libs/torch_lib/e2e_ops_tests.py +++ b/tests/function_libs/torch_lib/e2e_ops_tests.py @@ -914,6 +914,58 @@ def forward(self, x): ) _testing.assert_onnx_program(onnx_program) + def test_max_dim_negative_dim_squeeze_stability(self): + """Ensure max.dim(dim=-1, keepdim=False) exports and runs correctly. + + TopK + Squeeze(axes=[dim]) receives dim=-1. Validates that ORT handles + negative axis in Squeeze and output shape matches PyTorch. + """ + + class Model(torch.nn.Module): + def forward(self, x): + return torch.max(x, dim=-1, keepdim=False) + + x = torch.randn(2, 3, 4) + onnx_program = torch.onnx.export( + Model(), (x,), dynamo=True, verbose=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_min_dim_negative_dim_squeeze_stability(self): + """Ensure min.dim(dim=-1, keepdim=False) exports and runs correctly. + + Same as max_dim_negative_dim: TopK + Squeeze(axes=[dim]) with dim=-1. + """ + + class Model(torch.nn.Module): + def forward(self, x): + return torch.min(x, dim=-1, keepdim=False) + + x = torch.randn(2, 3, 4) + onnx_program = torch.onnx.export( + Model(), (x,), dynamo=True, verbose=False + ) + _testing.assert_onnx_program(onnx_program) + + def test_max_dim_chained_reduction(self): + """Ensure x.max(dim=1).values.max(dim=0) exports and runs correctly. 
+
+        Validates that TopK -> Squeeze -> next TopK -> Squeeze shape flow
+        is correct when chaining max.dim calls.
+        """
+
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                v1, _ = x.max(dim=1, keepdim=False)
+                v2, _ = v1.max(dim=0, keepdim=False)
+                return v2
+
+        x = torch.randn(2, 3, 4)
+        onnx_program = torch.onnx.export(
+            Model(), (x,), dynamo=True, verbose=False
+        )
+        _testing.assert_onnx_program(onnx_program)
+
 
 if __name__ == "__main__":
     unittest.main()

From 312ef9e89debbd566cf0f3ebae12fbef2f845233 Mon Sep 17 00:00:00 2001
From: Daniel Tu
Date: Mon, 2 Feb 2026 18:15:46 -0800
Subject: [PATCH 3/9] rewrite rule to fuse reduce + arg min/max with topk

---
 .../function_libs/torch_lib/ops/core.py       |  21 +-
 .../rules/common/_fuse_reduce_arg_to_topk.py  | 297 ++++++++++
 .../common/_fuse_reduce_arg_to_topk_test.py   | 543 ++++++++++++++++++
 .../function_libs/torch_lib/e2e_ops_tests.py  |  52 --
 4 files changed, 847 insertions(+), 66 deletions(-)
 create mode 100644 onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
 create mode 100644 onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 289ecd9ed2..860b878edb 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -6155,13 +6155,9 @@ def aten_max_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, I
         result = self
         indices = op.Constant(value_int=0)
     else:
-        values, indices = op.TopK(self, K=[1], axis=dim, largest=1, sorted=0)
-        if keepdim:
-            result = values
-        else:
-            squeeze_axe = op.Constant(value_ints=[dim])
-            result = op.Squeeze(values, axes=squeeze_axe)
-            indices = op.Squeeze(indices, axes=squeeze_axe)
+        dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
+        result = op.ReduceMax(self, dims, keepdims=keepdim)
+        indices = op.ArgMax(self, axis=dim, keepdims=keepdim)
 
     return result, indices
 
@@ -6246,13 +6242,10 @@ def aten_min_dim(self: TReal, dim: int, keepdim: bool = False) -> Tuple[TReal, T
         result = self
         indices = op.Constant(value_int=0)
     else:
-        values, indices = op.TopK(self, K=[1], axis=dim, largest=0, sorted=0)
-        if keepdim:
-            result = values
-        else:
-            squeeze_axe = op.Constant(value_ints=[dim])
-            result = op.Squeeze(values, axes=squeeze_axe)
-            indices = op.Squeeze(indices, axes=squeeze_axe)
+        dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
+        result = op.ReduceMin(self, dims, keepdims=keepdim)
+        indices = op.ArgMin(self, axis=dim, keepdims=keepdim)
+
     return result, indices
 
 
diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
new file mode 100644
index 0000000000..f4565663d4
--- /dev/null
+++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
@@ -0,0 +1,297 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Fuses Reduce{Max,Min} and Arg{Max,Min} patterns into TopK.
+
+Supported transformations:
+- ReduceMax(X, axes=[axis], keepdims=k) + ArgMax(X, axis=axis, keepdims=k) → TopK(X, k=1, axis=axis, largest=1) [+ Squeeze if k=0]
+- ReduceMin(X, axes=[axis], keepdims=k) + ArgMin(X, axis=axis, keepdims=k) → TopK(X, k=1, axis=axis, largest=0) [+ Squeeze if k=0]
+
+Supports both ONNX opset versions:
+  - Opset 13-17: Reduce{Max,Min} with axes as an attribute
+  - Opset 18+: Reduce{Max,Min} with axes as a second input
+
+Constraints:
+  - Both nodes must operate on the same input X.
+ - Both nodes must target the same axis. + - Both nodes must have the same keepdims attribute value. + - The Reduce node must operate on a single axis (len(axes) == 1). + - For opset 18+, the Reduce node's axes input must be a constant. +""" + +from __future__ import annotations + +from abc import abstractmethod + +import numpy as np +import onnx_ir as ir + +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class FuseReduceArgToTopKBase(RewriteRuleClassBase): + """Base class for fusing Reduce{Max,Min} + Arg{Max,Min} into TopK. + + This base class contains the common logic for checking and rewriting patterns where + a Reduce operation and its corresponding Arg operation can be replaced with a single + TopK operation. + + Subclasses must implement: + - pattern(): Define the specific Reduce and Arg operations to match + - reduce_op_type: Property returning the name of the Reduce op (e.g., "ReduceMax") + - arg_op_type: Property returning the name of the Arg op (e.g., "ArgMax") + - largest: Property returning 1 for Max operations, 0 for Min operations + """ + + @property + @abstractmethod + def reduce_op_type(self) -> str: + """Return the name of the Reduce operation""" + ... + + @property + @abstractmethod + def arg_op_type(self) -> str: + """Return the name of the Arg operation""" + ... + + @property + @abstractmethod + def largest(self) -> int: + """Return 1 for Max operations (largest elements), 0 for Min operations (smallest elements).""" + ... + + def check(self, context, reduce_val, arg_idx, **_) -> MatchResult: + """Check if Reduce and Arg operations can be safely fused into TopK. + + Conditions: + - Both nodes must have the same keepdims attribute. + - The Reduce node must operate on a single axis. + - Both nodes must operate on the same axis. + - The Arg node must not use select_last_index=1 (TopK doesn't support this). + + Args: + context: The rewrite context (unused). + reduce_val: The output of the Reduce operation (ReduceMax/ReduceMin). + arg_idx: The output of the Arg operation (ArgMax/ArgMin). + + Returns: + MatchResult: Success if the pattern can be fused, Failure otherwise. + """ + del context + check_result = MatchResult() + + reduce_node = reduce_val.producer() + arg_node = arg_idx.producer() + + if reduce_node is None or arg_node is None: + return check_result.fail("Cannot find producer nodes.") + + # Step 1: Get keepdims attribute from both nodes + reduce_keepdims_attr = reduce_node.attributes.get("keepdims") + arg_keepdims_attr = arg_node.attributes.get("keepdims") + + # ONNX default: keepdims = 1 for both Reduce and Arg operations + reduce_keepdims = ( + reduce_keepdims_attr.as_int() if reduce_keepdims_attr is not None else 1 + ) + arg_keepdims = arg_keepdims_attr.as_int() if arg_keepdims_attr is not None else 1 + + # Step 2: Check if keepdims match + if reduce_keepdims != arg_keepdims: + return check_result.fail( + f"keepdims mismatch: {self.reduce_op_type} has {reduce_keepdims}, " + f"{self.arg_op_type} has {arg_keepdims}." 
+ ) + + # Step 3: Get axes from Reduce operation + # In opset 18+, axes is an input; in opset 13-17, it's an attribute + reduce_axes_attr = reduce_node.attributes.get("axes") + + if reduce_axes_attr is not None: + # Opset 13-17: axes is an attribute + try: + axes_list = list(reduce_axes_attr.as_ints()) + except Exception: + return check_result.fail(f"Cannot parse {self.reduce_op_type} axes attribute.") + elif len(reduce_node.inputs) >= 2 and reduce_node.inputs[1] is not None: + # Opset 18+: axes is the second input + axes_input = reduce_node.inputs[1] + axes_const_value = axes_input.const_value + if axes_const_value is None: + return check_result.fail( + f"{self.reduce_op_type} axes input is not a constant." + ) + try: + axes_array = axes_const_value.numpy() + axes_list = axes_array.tolist() if axes_array.ndim > 0 else [int(axes_array)] + except Exception: + return check_result.fail(f"Cannot parse {self.reduce_op_type} axes input.") + else: + return check_result.fail( + f"{self.reduce_op_type} axes not found (neither attribute nor input)." + ) + + # Step 4: Check that Reduce operates on exactly one axis + if len(axes_list) != 1: + return check_result.fail( + f"{self.reduce_op_type} must operate on a single axis, got {len(axes_list)} axes." + ) + + reduce_axis = axes_list[0] + + # Step 5: Get axis from Arg operation + # ONNX default: axis = 0 for ArgMax/ArgMin + arg_axis_attr = arg_node.attributes.get("axis") + arg_axis = arg_axis_attr.as_int() if arg_axis_attr is not None else 0 + + # Step 6: Check select_last_index attribute (if present) + # TopK always returns the first occurrence in case of ties + select_last_index_attr = arg_node.attributes.get("select_last_index") + if select_last_index_attr is not None and select_last_index_attr.as_int() != 0: + return check_result.fail( + f"{self.arg_op_type} has select_last_index=1, which is not supported by TopK." + ) + + # Step 7: Normalize axes if rank is known (handle negative indices) + input_x = reduce_node.inputs[0] + rank = len(input_x.shape) if input_x.shape is not None else None + + def normalize_axis(axis: int, rank: int | None) -> int: + if rank is not None and axis < 0: + return axis + rank + return axis + + if normalize_axis(reduce_axis, rank) != normalize_axis(arg_axis, rank): + return check_result.fail( + f"Axis mismatch: {self.reduce_op_type} operates on axis {reduce_axis}, " + f"{self.arg_op_type} operates on axis {arg_axis}." + ) + + return check_result + + def rewrite(self, op, x, reduce_val, arg_idx): + """Rewrite the matched pattern with TopK (and optionally Squeeze). + + Args: + op: The operation builder. + x: The input to both Reduce and Arg operations. + reduce_val: The output of the Reduce operation. + arg_idx: The output of the Arg operation. + + Returns: + Tuple of (values, indices) matching the original outputs. 
+ """ + # Step 1: Get the nodes + arg_node = arg_idx.producer() + + # Step 2: Extract necessary attributes with ONNX default values + axis_attr = arg_node.attributes.get("axis") + keepdims_attr = arg_node.attributes.get("keepdims") + + axis = axis_attr.as_int() if axis_attr is not None else 0 + keepdims = keepdims_attr.as_int() if keepdims_attr is not None else 1 + + # Step 2b: Normalize axis (convert negative to positive) if rank is known + if axis < 0 and x.shape is not None: + axis = len(x.shape) + axis + + # Step 3: Create K constant + k_constant = op.Constant(value=ir.tensor(np.array([1], dtype=np.int64))) + + # Step 4: Create TopK node + topk_values, topk_indices = op.TopK( + x, + k_constant, + axis=axis, + largest=self.largest, + sorted=1, + _outputs=2, + ) + + # Step 5: Handle keepdims=0 case + if keepdims == 0: + # TopK output always keeps the dimension (just makes it size 1) + # We need to squeeze it to match the original Reduce/Arg behavior + axes_constant = op.Constant(value=ir.tensor(np.array([axis], dtype=np.int64))) + + new_values = op.Squeeze(topk_values, axes_constant) + new_indices = op.Squeeze(topk_indices, axes_constant) + else: + new_values = topk_values + new_indices = topk_indices + + return new_values, new_indices + + +class FuseReduceMaxArgMaxToTopK(FuseReduceArgToTopKBase): + """Replaces ReduceMax + ArgMax with TopK(largest=1). + + Transformation: + ReduceMax(X, axes=[axis], keepdims=k) + ArgMax(X, axis=axis, keepdims=k) + → TopK(X, k=1, axis=axis, largest=1) [+ Squeeze if k=0] + + When keepdims=0, the output of TopK is squeezed to match the original output shapes. + """ + + @property + def reduce_op_type(self) -> str: + return "ReduceMax" + + + @property + def arg_op_type(self) -> str: + return "ArgMax" + + @property + def largest(self) -> int: + return 1 # TopK returns largest elements + + def pattern(self, op, x): + """Define the pattern to match: ReduceMax and ArgMax on the same input. + + Note: For opset 18+, ReduceMax has a second input for axes, which we allow + but will validate in check() to ensure it's a constant. + """ + reduce_val = op.ReduceMax(x, _allow_other_inputs=True, _outputs=["reduce_val"]) + arg_idx = op.ArgMax(x, _outputs=["arg_idx"]) + return reduce_val, arg_idx + + +class FuseReduceMinArgMinToTopK(FuseReduceArgToTopKBase): + """Replaces ReduceMin + ArgMin with TopK(largest=0). + + Transformation: + ReduceMin(X, axes=[axis], keepdims=k) + ArgMin(X, axis=axis, keepdims=k) + → TopK(X, k=1, axis=axis, largest=0) [+ Squeeze if k=0] + + When keepdims=0, the output of TopK is squeezed to match the original output shapes. + """ + + @property + def reduce_op_type(self) -> str: + return "ReduceMin" + + @property + def arg_op_type(self) -> str: + return "ArgMin" + + @property + def largest(self) -> int: + return 0 # TopK returns smallest elements + + def pattern(self, op, x): + """Define the pattern to match: ReduceMin and ArgMin on the same input. + + Note: For opset 18+, ReduceMin has a second input for axes, which we allow + but will validate in check() to ensure it's a constant. 
+ """ + reduce_val = op.ReduceMin(x, _allow_other_inputs=True, _outputs=["reduce_val"]) + arg_idx = op.ArgMin(x, _outputs=["arg_idx"]) + return reduce_val, arg_idx + + +reduce_max_argmax_to_topk_rule = FuseReduceMaxArgMaxToTopK().rule() +reduce_min_argmin_to_topk_rule = FuseReduceMinArgMinToTopK().rule() + +rules = RewriteRuleSet([reduce_max_argmax_to_topk_rule, reduce_min_argmin_to_topk_rule]) diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py new file mode 100644 index 0000000000..b9a3594de6 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py @@ -0,0 +1,543 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +from onnx_ir.passes.common import onnx_checker, shape_inference +from parameterized import parameterized + +from onnxscript.rewriter import MatchingTracer, MatchStatus, testing +from onnxscript.rewriter.rules.common._fuse_reduce_arg_to_topk import ( + reduce_max_argmax_to_topk_rule, + reduce_min_argmin_to_topk_rule, + rules, +) + + +class FuseReduceArgToTopKTestBase(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20260127) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def run_test( + self, + base_model: ir.Model, + expected_op_types: list[str], + ): + onnx_checker.CheckerPass(True)(base_model) + base_model = shape_inference.infer_shapes(base_model) + updated_model = self.clone_model(base_model) + count = rules.apply_to_model(updated_model) + + # Check that the rule was applied + self.assertGreater(count, 0) + + # Check expected op_types + self.assertEqual([node.op_type for node in updated_model.graph], expected_op_types) + + # Check inference + inputs = ( + self.rng.uniform( + low=-10.0, + high=10.0, + size=(2, *updated_model.graph.inputs[0].shape[1:]), + ).astype(np.float32), + ) + + testing.assert_numerically_equal( + base_model, + updated_model, + inputs, + ) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def run_failed_condition_test( + self, + base_model: ir.Model, + rule, + expected_message: str, + ): + onnx_checker.CheckerPass(True)(base_model) + + updated_model = self.clone_model(base_model) + tracer = MatchingTracer() + count = rule.apply_to_model(updated_model, tracer=tracer) + + # Check that the model is unchanged + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[rule][0] + self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, expected_message) + + +class TestFuseReduceMaxArgMaxToTopK(FuseReduceArgToTopKTestBase): + @parameterized.expand( + [ + ("keepdims_1_axis_1", 1, 1), + ("keepdims_1_axis_2", 1, 2), + ("keepdims_1_axis_neg1", 1, -1), + ("keepdims_0_axis_1", 0, 1), + ("keepdims_0_axis_2", 0, 2), + ("keepdims_0_axis_neg1", 0, -1), + ] + ) + def test_successful_fuse_reduce_argmax_to_topk(self, _, keepdims, axis): + """Test fusion of ReduceMax + ArgMax into TopK with various keepdims and axis values.""" + # When keepdims=0, the output rank is reduced by 1 + if keepdims == 0: + output_shape_str = "[N, ?, ?]" + else: + output_shape_str = "[N, ?, ?, ?]" + + # Test with opset 13 (axes as attribute) + 
base_model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} max_val, int64{output_shape_str} max_idx)
+            {{
+                max_val = ReduceMax<axes = [{axis}], keepdims = {keepdims}>(X)
+                max_idx = ArgMax<axis = {axis}, keepdims = {keepdims}>(X)
+            }}
+        """)
+
+        # Expected: Constant for K, TopK, possibly (Constant + Squeeze) x2 for keepdims=0
+        if keepdims == 0:
+            expected_op_types = ["Constant", "TopK", "Constant", "Squeeze", "Squeeze"]
+        else:
+            expected_op_types = ["Constant", "TopK"]
+
+        self.run_test(base_model, expected_op_types)
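+
+    # For illustration, with keepdims=0 and axis=2 the model text above renders as
+    #     max_val = ReduceMax<axes = [2], keepdims = 0>(X)
+    #     max_idx = ArgMax<axis = 2, keepdims = 0>(X)
+    # and both declared outputs drop the reduced dimension, giving shape [N, ?, ?].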
+
+    @parameterized.expand(
+        [
+            ("keepdims_1_axis_1", 1, 1),
+            ("keepdims_0_axis_2", 0, 2),
+        ]
+    )
+    def test_successful_fuse_reduce_argmax_to_topk_opset18(self, _, keepdims, axis):
+        """Test fusion with opset 18+ (axes as input)."""
+        if keepdims == 0:
+            output_shape_str = "[N, ?, ?]"
+        else:
+            output_shape_str = "[N, ?, ?, ?]"
+
+        # In opset 18+, axes must be passed as the second input to ReduceMax
+        base_model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 18] >
+            test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} max_val, int64{output_shape_str} max_idx)
+            <int64[1] axes = {{{axis}}}>
+            {{
+                max_val = ReduceMax<keepdims = {keepdims}>(X, axes)
+                max_idx = ArgMax<axis = {axis}, keepdims = {keepdims}>(X)
+            }}
+        """)
+
+        # Expected: Constant for K, TopK, possibly (Constant + Squeeze) x2 for keepdims=0
+        if keepdims == 0:
+            expected_op_types = ["Constant", "TopK", "Constant", "Squeeze", "Squeeze"]
+        else:
+            expected_op_types = ["Constant", "TopK"]
+
+        self.run_test(base_model, expected_op_types)
+
+    def test_fuse_reduce_argmax_explicit_axis_0(self):
+        """Test fusion with explicit axis=0."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 14, 17] X) => (float[1, 14, 17] max_val, int64[1, 14, 17] max_idx)
+            {
+                max_val = ReduceMax<axes = [0]>(X)
+                max_idx = ArgMax<axis = 0>(X)
+            }
+        """)
+
+        expected_op_types = ["Constant", "TopK"]
+        self.run_test(base_model, expected_op_types)
+
+    def test_successful_fuse_reduce_argmax_mixed_negative_positive_axes(self):
+        """Test fusion when ReduceMax uses negative axis and ArgMax uses positive axis.
+
+        Input shape is [N, 32, 14, 17], rank is 4. Axis -1 is equivalent to axis 3.
+        The rule should normalize both axes before comparison.
+        """
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X) => (float[N, 32, 14, 1] max_val, int64[N, 32, 14, 1] max_idx)
+            {
+                max_val = ReduceMax<axes = [-1]>(X)
+                max_idx = ArgMax<axis = 3>(X)
+            }
+        """)
+        expected_op_types = ["Constant", "TopK"]
+        self.run_test(base_model, expected_op_types)
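+
+    # Axis normalization, for reference: with rank 4, axis -1 normalizes to
+    # -1 + 4 = 3, so ReduceMax(axes=[-1]) and ArgMax(axis=3) address the same
+    # dimension and the pair above is still fusable.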
+
+    def test_fail_keepdims_mismatch(self):
+        """Test that fusion fails when keepdims values don't match."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?] max_idx)
+            {
+                max_val = ReduceMax<axes = [1], keepdims = 1>(X)
+                max_idx = ArgMax<axis = 1, keepdims = 0>(X)
+            }
+        """)
+
+        self.run_failed_condition_test(
+            base_model, reduce_max_argmax_to_topk_rule, "keepdims mismatch"
+        )
+
+    def test_fail_axis_mismatch(self):
+        """Test that fusion fails when axes don't match."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx)
+            {
+                max_val = ReduceMax<axes = [1]>(X)
+                max_idx = ArgMax<axis = 2>(X)
+            }
+        """)
+
+        self.run_failed_condition_test(
+            base_model, reduce_max_argmax_to_topk_rule, "Axis mismatch"
+        )
+
+    def test_fail_multiple_axes_reduce_max(self):
+        """Test that fusion fails when ReduceMax operates on multiple axes."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx)
+            {
+                max_val = ReduceMax<axes = [1, 2]>(X)
+                max_idx = ArgMax<axis = 1>(X)
+            }
+        """)
+
+        self.run_failed_condition_test(
+            base_model,
+            reduce_max_argmax_to_topk_rule,
+            "ReduceMax must operate on a single axis",
+        )
+
+    def test_fail_select_last_index_argmax(self):
+        """Test that fusion fails when ArgMax has select_last_index=1."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx)
+            {
+                max_val = ReduceMax<axes = [1]>(X)
+                max_idx = ArgMax<axis = 1, select_last_index = 1>(X)
+            }
+        """)
+
+        self.run_failed_condition_test(
+            base_model,
+            reduce_max_argmax_to_topk_rule,
+            "ArgMax has select_last_index=1, which is not supported by TopK.",
+        )
+
+    def test_successful_fuse_with_default_keepdims(self):
+        """Test fusion with default keepdims (should be 1)."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx)
+            {
+                max_val = ReduceMax<axes = [1]>(X)
+                max_idx = ArgMax<axis = 1>(X)
+            }
+        """)
+
+        # Both should use default keepdims=1, so fusion should succeed
+        expected_op_types = ["Constant", "TopK"]
+        self.run_test(base_model, expected_op_types)
+
+    def test_successful_fuse_with_default_axis(self):
+        """Test fusion with default axis (should be 0)."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 14, 17] X) => (float[1, 14, 17] max_val, int64[1, 14, 17] max_idx)
+            {
+                max_val = ReduceMax<axes = [0], keepdims = 1>(X)
+                max_idx = ArgMax<keepdims = 1>(X)
+            }
+        """)
+
+        # ArgMax should use default axis=0, so fusion should succeed
+        expected_op_types = ["Constant", "TopK"]
+        self.run_test(base_model, expected_op_types)
+
+    def test_successful_fuse_with_all_defaults(self):
+        """Test fusion with all default values (keepdims=1, axis=0)."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 14, 17] X) => (float[1, 14, 17] max_val, int64[1, 14, 17] max_idx)
+            {
+                max_val = ReduceMax<axes = [0]>(X)
+                max_idx = ArgMax(X)
+            }
+        """)
+
+        # Both should use defaults: keepdims=1, axis=0
+        expected_op_types = ["Constant", "TopK"]
+        self.run_test(base_model, expected_op_types)
+
+    def test_no_fusion_different_inputs(self):
+        """Test that fusion doesn't happen when nodes have different inputs."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X, float[N, 32, 14, 17] Y) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx)
+            {
+                max_val = ReduceMax<axes = [1]>(X)
+                max_idx = ArgMax<axis = 1>(Y)
+            }
+        """)
+
+        # Pattern won't match at all because inputs are different
+        updated_model = self.clone_model(base_model)
+        count = rules.apply_to_model(updated_model)
+        self.assertEqual(count, 0)
+
+        # Model should be unchanged
+        self.assertEqual(
+            [node.op_type for node in base_model.graph],
+            [node.op_type for node in updated_model.graph],
+        )
+
+
+class TestFuseReduceMinArgMinToTopK(FuseReduceArgToTopKTestBase):
+    """Test cases for ReduceMin + ArgMin → TopK(largest=0) fusion."""
+
+    @parameterized.expand(
+        [
+            ("keepdims_1_axis_1", 1, 1),
+            ("keepdims_1_axis_2", 1, 2),
+            ("keepdims_1_axis_neg1", 1, -1),
+            ("keepdims_0_axis_1", 0, 1),
+            ("keepdims_0_axis_2", 0, 2),
+            ("keepdims_0_axis_neg1", 0, -1),
+        ]
+    )
+    def test_successful_fuse_reduce_argmin_to_topk(self, _, keepdims, axis):
+        """Test fusion of ReduceMin + ArgMin into TopK with various keepdims and axis values."""
+        if keepdims == 0:
+            output_shape_str = "[N, ?, ?]"
+        else:
+            output_shape_str = "[N, ?, ?, ?]"
+
+        base_model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} min_val, int64{output_shape_str} min_idx)
+            {{
+                min_val = ReduceMin<axes = [{axis}], keepdims = {keepdims}>(X)
+                min_idx = ArgMin<axis = {axis}, keepdims = {keepdims}>(X)
+            }}
+        """)
+
+        # Expected: Constant for K, TopK, possibly (Constant + Squeeze) x2 for keepdims=0
+        if keepdims == 0:
+            expected_op_types = ["Constant", "TopK", "Constant", "Squeeze", "Squeeze"]
+        else:
+            expected_op_types = ["Constant", "TopK"]
+
+        self.run_test(base_model, expected_op_types)
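+
+    # The Min cases mirror the Max cases above; the only behavioral difference is
+    # TopK(largest=0), which selects the smallest element. Ties resolve to the
+    # first occurrence, matching the ArgMin default (select_last_index=0).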
+ """ + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 13] > + test_model (float[N, 32, 14, 17] X) => (float[N, 32, 1, 17] min_val, int64[N, 32, 1, 17] min_idx) + { + min_val = ReduceMin(X) + min_idx = ArgMin(X) + } + """) + expected_op_types = ["Constant", "TopK"] + self.run_test(base_model, expected_op_types) + + def test_fail_axis_mismatch(self): + """Test that fusion fails when axes don't match for Min operations.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 13] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx) + { + min_val = ReduceMin(X) + min_idx = ArgMin(X) + } + """) + self.run_failed_condition_test( + base_model, reduce_min_argmin_to_topk_rule, "Axis mismatch" + ) + + def test_fail_keepdims_mismatch(self): + """Test that fusion fails when keepdims values don't match for Min operations.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 13] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?] min_idx) + { + min_val = ReduceMin(X) + min_idx = ArgMin(X) + } + """) + self.run_failed_condition_test( + base_model, reduce_min_argmin_to_topk_rule, "keepdims mismatch" + ) + + def test_fail_multiple_axes_reduce_min(self): + """Test that fusion fails when ReduceMin operates on multiple axes.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 13] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx) + { + min_val = ReduceMin(X) + min_idx = ArgMin(X) + } + """) + + self.run_failed_condition_test( + base_model, + reduce_min_argmin_to_topk_rule, + "ReduceMin must operate on a single axis", + ) + + def test_fail_select_last_index_argmin(self): + """Test that fusion fails when ArgMin has select_last_index=1.""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 13] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx) + { + min_val = ReduceMin(X) + min_idx = ArgMin(X) + } + """) + self.run_failed_condition_test( + base_model, + reduce_min_argmin_to_topk_rule, + "ArgMin has select_last_index=1, which is not supported by TopK.", + ) + + def test_successful_fuse_with_default_keepdims(self): + """Test fusion with default keepdims (should be 1).""" + base_model = ir.from_onnx_text(""" + < ir_version: 10, opset_import: ["" : 13] > + test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] 
+
+    def test_fail_multiple_axes_reduce_min(self):
+        """Test that fusion fails when ReduceMin operates on multiple axes."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx)
+            {
+                min_val = ReduceMin<axes = [1, 2]>(X)
+                min_idx = ArgMin<axis = 1>(X)
+            }
+        """)
+
+        self.run_failed_condition_test(
+            base_model,
+            reduce_min_argmin_to_topk_rule,
+            "ReduceMin must operate on a single axis",
+        )
+
+    def test_fail_select_last_index_argmin(self):
+        """Test that fusion fails when ArgMin has select_last_index=1."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx)
+            {
+                min_val = ReduceMin<axes = [1]>(X)
+                min_idx = ArgMin<axis = 1, select_last_index = 1>(X)
+            }
+        """)
+        self.run_failed_condition_test(
+            base_model,
+            reduce_min_argmin_to_topk_rule,
+            "ArgMin has select_last_index=1, which is not supported by TopK.",
+        )
+
+    def test_successful_fuse_with_default_keepdims(self):
+        """Test fusion with default keepdims (should be 1)."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx)
+            {
+                min_val = ReduceMin<axes = [1]>(X)
+                min_idx = ArgMin<axis = 1>(X)
+            }
+        """)
+
+        # Both should use default keepdims=1, so fusion should succeed
+        expected_op_types = ["Constant", "TopK"]
+        self.run_test(base_model, expected_op_types)
+
+    def test_successful_fuse_with_default_axis(self):
+        """Test fusion with default axis (should be 0)."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 14, 17] X) => (float[1, 14, 17] min_val, int64[1, 14, 17] min_idx)
+            {
+                min_val = ReduceMin<axes = [0], keepdims = 1>(X)
+                min_idx = ArgMin<keepdims = 1>(X)
+            }
+        """)
+
+        # ArgMin should use default axis=0, so fusion should succeed
+        expected_op_types = ["Constant", "TopK"]
+        self.run_test(base_model, expected_op_types)
+
+    def test_successful_fuse_with_all_defaults(self):
+        """Test fusion with all default values (keepdims=1, axis=0)."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 14, 17] X) => (float[1, 14, 17] min_val, int64[1, 14, 17] min_idx)
+            {
+                min_val = ReduceMin<axes = [0]>(X)
+                min_idx = ArgMin(X)
+            }
+        """)
+
+        # Both should use defaults: keepdims=1, axis=0
+        expected_op_types = ["Constant", "TopK"]
+        self.run_test(base_model, expected_op_types)
+
+    def test_no_fusion_different_inputs(self):
+        """Test that fusion doesn't happen when nodes have different inputs."""
+        base_model = ir.from_onnx_text("""
+            < ir_version: 10, opset_import: ["" : 13] >
+            test_model (float[N, 32, 14, 17] X, float[N, 32, 14, 17] Y) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx)
+            {
+                min_val = ReduceMin<axes = [1]>(X)
+                min_idx = ArgMin<axis = 1>(Y)
+            }
+        """)
+
+        # Pattern won't match at all because inputs are different
+        updated_model = self.clone_model(base_model)
+        count = rules.apply_to_model(updated_model)
+        self.assertEqual(count, 0)
+
+        # Model should be unchanged
+        self.assertEqual(
+            [node.op_type for node in base_model.graph],
+            [node.op_type for node in updated_model.graph],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py
index 84a4067b3f..9ee12f3ac3 100644
--- a/tests/function_libs/torch_lib/e2e_ops_tests.py
+++ b/tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -914,58 +914,6 @@ def forward(self, x):
         )
         _testing.assert_onnx_program(onnx_program)
 
-    def test_max_dim_negative_dim_squeeze_stability(self):
-        """Ensure max.dim(dim=-1, keepdim=False) exports and runs correctly.
-
-        TopK + Squeeze(axes=[dim]) receives dim=-1. Validates that ORT handles
-        negative axis in Squeeze and output shape matches PyTorch.
-        """
-
-        class Model(torch.nn.Module):
-            def forward(self, x):
-                return torch.max(x, dim=-1, keepdim=False)
-
-        x = torch.randn(2, 3, 4)
-        onnx_program = torch.onnx.export(
-            Model(), (x,), dynamo=True, verbose=False
-        )
-        _testing.assert_onnx_program(onnx_program)
-
-    def test_min_dim_negative_dim_squeeze_stability(self):
-        """Ensure min.dim(dim=-1, keepdim=False) exports and runs correctly.
-
-        Same as max_dim_negative_dim: TopK + Squeeze(axes=[dim]) with dim=-1.
-        """
-
-        class Model(torch.nn.Module):
-            def forward(self, x):
-                return torch.min(x, dim=-1, keepdim=False)
-
-        x = torch.randn(2, 3, 4)
-        onnx_program = torch.onnx.export(
-            Model(), (x,), dynamo=True, verbose=False
-        )
-        _testing.assert_onnx_program(onnx_program)
-
-    def test_max_dim_chained_reduction(self):
-        """Ensure x.max(dim=1).values.max(dim=0) exports and runs correctly.
- - Validates that TopK -> Squeeze -> next TopK -> Squeeze shape flow - is correct when chaining max.dim calls. - """ - - class Model(torch.nn.Module): - def forward(self, x): - v1, _ = x.max(dim=1, keepdim=False) - v2, _ = v1.max(dim=0, keepdim=False) - return v2 - - x = torch.randn(2, 3, 4) - onnx_program = torch.onnx.export( - Model(), (x,), dynamo=True, verbose=False - ) - _testing.assert_onnx_program(onnx_program) - if __name__ == "__main__": unittest.main() From dde8d14ca82d1974746868514700ac1317dfa352 Mon Sep 17 00:00:00 2001 From: Daniel Tu Date: Tue, 3 Feb 2026 23:19:44 -0800 Subject: [PATCH 4/9] format --- onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py | 1 - .../rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py index f4565663d4..623be62d3d 100644 --- a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py +++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py @@ -237,7 +237,6 @@ class FuseReduceMaxArgMaxToTopK(FuseReduceArgToTopKBase): @property def reduce_op_type(self) -> str: return "ReduceMax" - @property def arg_op_type(self) -> str: diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py index b9a3594de6..a2e603314c 100644 --- a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py @@ -540,4 +540,4 @@ def test_no_fusion_different_inputs(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 72fe1dbda6b4eea9682f04c1ab693947df70a97e Mon Sep 17 00:00:00 2001 From: Daniel Tu Date: Fri, 6 Feb 2026 22:55:52 -0800 Subject: [PATCH 5/9] solve comments - part 1 --- onnxscript/rewriter/__init__.py | 2 + .../rules/common/_fuse_reduce_arg_to_topk.py | 82 +++++++------------ 2 files changed, 31 insertions(+), 53 deletions(-) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index fb93bc703f..6f9076e3d2 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -41,6 +41,7 @@ _collapse_slices, _fuse_batchnorm, _fuse_pad_into_conv, + _fuse_reduce_arg_to_topk, _fuse_relus_clips, _min_max_to_clip, _no_op, @@ -61,6 +62,7 @@ *_fuse_pad_into_conv.rules, *_fuse_batchnorm.rules, *_remove_optional_bias.rules, + *_fuse_reduce_arg_to_topk.rules, ) diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py index 623be62d3d..2d1291271b 100644 --- a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py +++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py @@ -29,7 +29,7 @@ from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet -class FuseReduceArgToTopKBase(RewriteRuleClassBase): +class _FuseReduceArgToTopKBase(RewriteRuleClassBase): """Base class for fusing Reduce{Max,Min} + Arg{Max,Min} into TopK. 
This base class contains the common logic for checking and rewriting patterns where @@ -38,28 +38,28 @@ class FuseReduceArgToTopKBase(RewriteRuleClassBase): Subclasses must implement: - pattern(): Define the specific Reduce and Arg operations to match - - reduce_op_type: Property returning the name of the Reduce op (e.g., "ReduceMax") - - arg_op_type: Property returning the name of the Arg op (e.g., "ArgMax") - largest: Property returning 1 for Max operations, 0 for Min operations """ - @property - @abstractmethod - def reduce_op_type(self) -> str: - """Return the name of the Reduce operation""" - ... - - @property - @abstractmethod - def arg_op_type(self) -> str: - """Return the name of the Arg operation""" - ... - @property @abstractmethod def largest(self) -> int: """Return 1 for Max operations (largest elements), 0 for Min operations (smallest elements).""" - ... + + @staticmethod + def _normalize_axis(axis: int, rank: int | None) -> int: + """Normalize a potentially negative axis to a positive axis index. + + Args: + axis: The axis to normalize (can be negative). + rank: The rank of the tensor, or None if unknown. + + Returns: + The normalized axis (non-negative if rank is known and axis was negative). + """ + if rank is not None and axis < 0: + return axis + rank + return axis def check(self, context, reduce_val, arg_idx, **_) -> MatchResult: """Check if Reduce and Arg operations can be safely fused into TopK. @@ -84,9 +84,6 @@ def check(self, context, reduce_val, arg_idx, **_) -> MatchResult: reduce_node = reduce_val.producer() arg_node = arg_idx.producer() - if reduce_node is None or arg_node is None: - return check_result.fail("Cannot find producer nodes.") - # Step 1: Get keepdims attribute from both nodes reduce_keepdims_attr = reduce_node.attributes.get("keepdims") arg_keepdims_attr = arg_node.attributes.get("keepdims") @@ -100,8 +97,8 @@ def check(self, context, reduce_val, arg_idx, **_) -> MatchResult: # Step 2: Check if keepdims match if reduce_keepdims != arg_keepdims: return check_result.fail( - f"keepdims mismatch: {self.reduce_op_type} has {reduce_keepdims}, " - f"{self.arg_op_type} has {arg_keepdims}." + f"keepdims mismatch: {reduce_node.op_type} has {reduce_keepdims}, " + f"{arg_node.op_type} has {arg_keepdims}." ) # Step 3: Get axes from Reduce operation @@ -113,29 +110,29 @@ def check(self, context, reduce_val, arg_idx, **_) -> MatchResult: try: axes_list = list(reduce_axes_attr.as_ints()) except Exception: - return check_result.fail(f"Cannot parse {self.reduce_op_type} axes attribute.") + return check_result.fail(f"Cannot parse {reduce_node.op_type} axes attribute.") elif len(reduce_node.inputs) >= 2 and reduce_node.inputs[1] is not None: # Opset 18+: axes is the second input axes_input = reduce_node.inputs[1] axes_const_value = axes_input.const_value if axes_const_value is None: return check_result.fail( - f"{self.reduce_op_type} axes input is not a constant." + f"{reduce_node.op_type} axes input is not a constant." ) try: axes_array = axes_const_value.numpy() axes_list = axes_array.tolist() if axes_array.ndim > 0 else [int(axes_array)] except Exception: - return check_result.fail(f"Cannot parse {self.reduce_op_type} axes input.") + return check_result.fail(f"Cannot parse {reduce_node.op_type} axes input.") else: return check_result.fail( - f"{self.reduce_op_type} axes not found (neither attribute nor input)." + f"{reduce_node.op_type} axes not found (neither attribute nor input)." 
) # Step 4: Check that Reduce operates on exactly one axis if len(axes_list) != 1: return check_result.fail( - f"{self.reduce_op_type} must operate on a single axis, got {len(axes_list)} axes." + f"{reduce_node.op_type} must operate on a single axis, got {len(axes_list)} axes." ) reduce_axis = axes_list[0] @@ -150,22 +147,17 @@ def check(self, context, reduce_val, arg_idx, **_) -> MatchResult: select_last_index_attr = arg_node.attributes.get("select_last_index") if select_last_index_attr is not None and select_last_index_attr.as_int() != 0: return check_result.fail( - f"{self.arg_op_type} has select_last_index=1, which is not supported by TopK." + f"{arg_node.op_type} has select_last_index=1, which is not supported by TopK." ) # Step 7: Normalize axes if rank is known (handle negative indices) input_x = reduce_node.inputs[0] rank = len(input_x.shape) if input_x.shape is not None else None - def normalize_axis(axis: int, rank: int | None) -> int: - if rank is not None and axis < 0: - return axis + rank - return axis - - if normalize_axis(reduce_axis, rank) != normalize_axis(arg_axis, rank): + if self._normalize_axis(reduce_axis, rank) != self._normalize_axis(arg_axis, rank): return check_result.fail( - f"Axis mismatch: {self.reduce_op_type} operates on axis {reduce_axis}, " - f"{self.arg_op_type} operates on axis {arg_axis}." + f"Axis mismatch: {reduce_node.op_type} operates on axis {reduce_axis}, " + f"{arg_node.op_type} operates on axis {arg_axis}." ) return check_result @@ -224,7 +216,7 @@ def rewrite(self, op, x, reduce_val, arg_idx): return new_values, new_indices -class FuseReduceMaxArgMaxToTopK(FuseReduceArgToTopKBase): +class FuseReduceMaxArgMaxToTopK(_FuseReduceArgToTopKBase): """Replaces ReduceMax + ArgMax with TopK(largest=1). Transformation: @@ -234,14 +226,6 @@ class FuseReduceMaxArgMaxToTopK(FuseReduceArgToTopKBase): When keepdims=0, the output of TopK is squeezed to match the original output shapes. """ - @property - def reduce_op_type(self) -> str: - return "ReduceMax" - - @property - def arg_op_type(self) -> str: - return "ArgMax" - @property def largest(self) -> int: return 1 # TopK returns largest elements @@ -257,7 +241,7 @@ def pattern(self, op, x): return reduce_val, arg_idx -class FuseReduceMinArgMinToTopK(FuseReduceArgToTopKBase): +class FuseReduceMinArgMinToTopK(_FuseReduceArgToTopKBase): """Replaces ReduceMin + ArgMin with TopK(largest=0). Transformation: @@ -267,14 +251,6 @@ class FuseReduceMinArgMinToTopK(FuseReduceArgToTopKBase): When keepdims=0, the output of TopK is squeezed to match the original output shapes. 
""" - @property - def reduce_op_type(self) -> str: - return "ReduceMin" - - @property - def arg_op_type(self) -> str: - return "ArgMin" - @property def largest(self) -> int: return 0 # TopK returns smallest elements From 301635e92d260686582b54e31022c3059cb23f18 Mon Sep 17 00:00:00 2001 From: Daniel Tu Date: Fri, 6 Feb 2026 23:33:49 -0800 Subject: [PATCH 6/9] resolve comments from copilot --- .../rules/common/_fuse_reduce_arg_to_topk.py | 6 +- .../common/_fuse_reduce_arg_to_topk_test.py | 80 ++++++++++++++++++- 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py index 2d1291271b..57461ff7ad 100644 --- a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py +++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py @@ -114,13 +114,13 @@ def check(self, context, reduce_val, arg_idx, **_) -> MatchResult: elif len(reduce_node.inputs) >= 2 and reduce_node.inputs[1] is not None: # Opset 18+: axes is the second input axes_input = reduce_node.inputs[1] - axes_const_value = axes_input.const_value - if axes_const_value is None: + axes_tensor = ir.convenience.get_const_tensor(axes_input) + if axes_tensor is None: return check_result.fail( f"{reduce_node.op_type} axes input is not a constant." ) try: - axes_array = axes_const_value.numpy() + axes_array = axes_tensor.numpy() axes_list = axes_array.tolist() if axes_array.ndim > 0 else [int(axes_array)] except Exception: return check_result.fail(f"Cannot parse {reduce_node.op_type} axes input.") diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py index a2e603314c..dc2ab3f7c1 100644 --- a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py +++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py @@ -77,7 +77,7 @@ def run_failed_condition_test( # Check that the error message is the expected one tracer_match = tracer.best_matches_map[rule][0] - self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) + self.assertEqual(tracer_match.status, MatchStatus.CONDITION_FAILED) self.assertRegex(tracer_match.match_result.reason, expected_message) @@ -164,6 +164,45 @@ def test_fuse_reduce_argmax_explicit_axis_0(self): expected_op_types = ["Constant", "TopK"] self.run_test(base_model, expected_op_types) + @parameterized.expand( + [ + ("keepdims_1_axis_2", 1, 2), + ("keepdims_0_axis_2", 0, 2), + ] + ) + def test_fuse_reduce_argmax_axes_from_constant_node(self, _, keepdims, axis): + """Test fusion when axes come from a Constant node (opset 18+).""" + if keepdims == 0: + output_shape_str = "[N, ?, ?]" + else: + output_shape_str = "[N, ?, ?, ?]" + + base_model = ir.from_onnx_text(f""" + < ir_version: 10, opset_import: ["" : 18] > + test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} max_val, int64{output_shape_str} max_idx) + {{ + axes = Constant() + max_val = ReduceMax(X, axes) + max_idx = ArgMax(X) + }} + """) + + # Expected: Constant for axes, Constant for K, TopK, + # possibly (Constant + Squeeze) x2 for keepdims=0 + if keepdims == 0: + expected_op_types = [ + "Constant", + "Constant", + "TopK", + "Constant", + "Squeeze", + "Squeeze", + ] + else: + expected_op_types = ["Constant", "Constant", "TopK"] + + self.run_test(base_model, expected_op_types) + def test_successful_fuse_reduce_argmax_mixed_negative_positive_axes(self): """Test fusion when ReduceMax uses negative axis 
 
     def test_successful_fuse_reduce_argmax_mixed_negative_positive_axes(self):
         """Test fusion when ReduceMax uses negative axis and ArgMax uses positive axis.
 
@@ -394,6 +433,45 @@ def test_fuse_reduce_argmin_explicit_axis_0(self):
         expected_op_types = ["Constant", "TopK"]
         self.run_test(base_model, expected_op_types)
 
+    @parameterized.expand(
+        [
+            ("keepdims_1_axis_2", 1, 2),
+            ("keepdims_0_axis_2", 0, 2),
+        ]
+    )
+    def test_fuse_reduce_argmin_axes_from_constant_node(self, _, keepdims, axis):
+        """Test fusion when axes come from a Constant node for Min operations (opset 18+)."""
+        if keepdims == 0:
+            output_shape_str = "[N, ?, ?]"
+        else:
+            output_shape_str = "[N, ?, ?, ?]"
+
+        base_model = ir.from_onnx_text(f"""
+            < ir_version: 10, opset_import: ["" : 18] >
+            test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} min_val, int64{output_shape_str} min_idx)
+            {{
+                axes = Constant<value = int64[1] {{{axis}}}>()
+                min_val = ReduceMin<keepdims = {keepdims}>(X, axes)
+                min_idx = ArgMin<axis = {axis}, keepdims = {keepdims}>(X)
+            }}
+        """)
+
+        # Expected: Constant for axes, Constant for K, TopK,
+        # possibly (Constant + Squeeze) x2 for keepdims=0
+        if keepdims == 0:
+            expected_op_types = [
+                "Constant",
+                "Constant",
+                "TopK",
+                "Constant",
+                "Squeeze",
+                "Squeeze",
+            ]
+        else:
+            expected_op_types = ["Constant", "Constant", "TopK"]
+
+        self.run_test(base_model, expected_op_types)
+
     def test_successful_fuse_reduce_argmin_mixed_axes(self):
         """Test fusion with mixed negative/positive axes for Min operations.
 
         Axis -2 is equivalent to axis 2 for rank-4 tensors.

From 81fa7135b0e29ef53c3e82db083a0bb7113cb597 Mon Sep 17 00:00:00 2001
From: Daniel Tu
Date: Thu, 12 Feb 2026 18:09:13 -0800
Subject: [PATCH 7/9] solve comments - part 3

---
 onnxscript/rewriter/__init__.py               |   2 -
 .../rules/common/_fuse_reduce_arg_to_topk.py  |  47 ++---
 .../common/_fuse_reduce_arg_to_topk_test.py   | 162 +++++-----------
 3 files changed, 70 insertions(+), 141 deletions(-)

diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py
index 6f9076e3d2..fb93bc703f 100644
--- a/onnxscript/rewriter/__init__.py
+++ b/onnxscript/rewriter/__init__.py
@@ -41,7 +41,6 @@
     _collapse_slices,
     _fuse_batchnorm,
     _fuse_pad_into_conv,
-    _fuse_reduce_arg_to_topk,
     _fuse_relus_clips,
     _min_max_to_clip,
     _no_op,
@@ -62,7 +61,6 @@
     *_fuse_pad_into_conv.rules,
     *_fuse_batchnorm.rules,
     *_remove_optional_bias.rules,
-    *_fuse_reduce_arg_to_topk.rules,
 )
 
 
diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
index 57461ff7ad..c6f541aaac 100644
--- a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
+++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
@@ -6,16 +6,11 @@
 - ReduceMax(X, axes=[axis], keepdims=k) + ArgMax(X, axis=axis, keepdims=k) → TopK(X, k=1, axis=axis, largest=1) [+ Squeeze if k=0]
 - ReduceMin(X, axes=[axis], keepdims=k) + ArgMin(X, axis=axis, keepdims=k) → TopK(X, k=1, axis=axis, largest=0) [+ Squeeze if k=0]
 
-Supports both ONNX opset versions:
-  - Opset 13-17: Reduce{Max,Min} with axes as an attribute
-  - Opset 18+: Reduce{Max,Min} with axes as a second input
-
 Constraints:
   - Both nodes must operate on the same input X.
   - Both nodes must target the same axis.
   - Both nodes must have the same keepdims attribute value.
   - The Reduce node must operate on a single axis (len(axes) == 1).
-  - For opset 18+, the Reduce node's axes input must be a constant.
""" from __future__ import annotations @@ -85,14 +80,8 @@ def check(self, context, reduce_val, arg_idx, **_) -> MatchResult: arg_node = arg_idx.producer() # Step 1: Get keepdims attribute from both nodes - reduce_keepdims_attr = reduce_node.attributes.get("keepdims") - arg_keepdims_attr = arg_node.attributes.get("keepdims") - - # ONNX default: keepdims = 1 for both Reduce and Arg operations - reduce_keepdims = ( - reduce_keepdims_attr.as_int() if reduce_keepdims_attr is not None else 1 - ) - arg_keepdims = arg_keepdims_attr.as_int() if arg_keepdims_attr is not None else 1 + reduce_keepdims = reduce_node.attributes.get_int("keepdims", 1) + arg_keepdims = arg_node.attributes.get_int("keepdims", 1) # Step 2: Check if keepdims match if reduce_keepdims != arg_keepdims: @@ -101,18 +90,8 @@ def check(self, context, reduce_val, arg_idx, **_) -> MatchResult: f"{arg_node.op_type} has {arg_keepdims}." ) - # Step 3: Get axes from Reduce operation - # In opset 18+, axes is an input; in opset 13-17, it's an attribute - reduce_axes_attr = reduce_node.attributes.get("axes") - - if reduce_axes_attr is not None: - # Opset 13-17: axes is an attribute - try: - axes_list = list(reduce_axes_attr.as_ints()) - except Exception: - return check_result.fail(f"Cannot parse {reduce_node.op_type} axes attribute.") - elif len(reduce_node.inputs) >= 2 and reduce_node.inputs[1] is not None: - # Opset 18+: axes is the second input + # Step 3: Get axes from Reduce ops's inputs + if len(reduce_node.inputs) >= 2 and reduce_node.inputs[1] is not None: axes_input = reduce_node.inputs[1] axes_tensor = ir.convenience.get_const_tensor(axes_input) if axes_tensor is None: @@ -125,9 +104,7 @@ def check(self, context, reduce_val, arg_idx, **_) -> MatchResult: except Exception: return check_result.fail(f"Cannot parse {reduce_node.op_type} axes input.") else: - return check_result.fail( - f"{reduce_node.op_type} axes not found (neither attribute nor input)." - ) + return check_result.fail(f"{reduce_node.op_type} axes not found in inputs.") # Step 4: Check that Reduce operates on exactly one axis if len(axes_list) != 1: @@ -139,13 +116,12 @@ def check(self, context, reduce_val, arg_idx, **_) -> MatchResult: # Step 5: Get axis from Arg operation # ONNX default: axis = 0 for ArgMax/ArgMin - arg_axis_attr = arg_node.attributes.get("axis") - arg_axis = arg_axis_attr.as_int() if arg_axis_attr is not None else 0 + arg_axis = arg_node.attributes.get_int("axis", 0) # Step 6: Check select_last_index attribute (if present) # TopK always returns the first occurrence in case of ties - select_last_index_attr = arg_node.attributes.get("select_last_index") - if select_last_index_attr is not None and select_last_index_attr.as_int() != 0: + select_last_index_attr = arg_node.attributes.get_int("select_last_index", 0) + if select_last_index_attr != 0: return check_result.fail( f"{arg_node.op_type} has select_last_index=1, which is not supported by TopK." 
             )
 
         # Step 1: Get the nodes
         arg_node = arg_idx.producer()
 
         # Step 2: Extract necessary attributes with ONNX default values
-        axis_attr = arg_node.attributes.get("axis")
-        keepdims_attr = arg_node.attributes.get("keepdims")
-
-        axis = axis_attr.as_int() if axis_attr is not None else 0
-        keepdims = keepdims_attr.as_int() if keepdims_attr is not None else 1
+        axis = arg_node.attributes.get_int("axis", 0)
+        keepdims = arg_node.attributes.get_int("keepdims", 1)
 
         # Step 2b: Normalize axis (convert negative to positive) if rank is known
         if axis < 0 and x.shape is not None:
diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py
index dc2ab3f7c1..58ed6dec4e 100644
--- a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py
+++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk_test.py
@@ -100,38 +100,6 @@ def test_successful_fuse_reduce_argmax_to_topk(self, _, keepdims, axis):
         else:
             output_shape_str = "[N, ?, ?, ?]"
 
-        # Test with opset 13 (axes as attribute)
-        base_model = ir.from_onnx_text(f"""
-            < ir_version: 10, opset_import: ["" : 13] >
-            test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} max_val, int64{output_shape_str} max_idx)
-            {{
-                max_val = ReduceMax<axes = [{axis}], keepdims = {keepdims}>(X)
-                max_idx = ArgMax<axis = {axis}, keepdims = {keepdims}>(X)
-            }}
-        """)
-
-        # Expected: Constant for K, TopK, possibly (Constant + Squeeze) x2 for keepdims=0
-        if keepdims == 0:
-            expected_op_types = ["Constant", "TopK", "Constant", "Squeeze", "Squeeze"]
-        else:
-            expected_op_types = ["Constant", "TopK"]
-
-        self.run_test(base_model, expected_op_types)
-
-    @parameterized.expand(
-        [
-            ("keepdims_1_axis_1", 1, 1),
-            ("keepdims_0_axis_2", 0, 2),
-        ]
-    )
-    def test_successful_fuse_reduce_argmax_to_topk_opset18(self, _, keepdims, axis):
-        """Test fusion with opset 18+ (axes as input)."""
-        if keepdims == 0:
-            output_shape_str = "[N, ?, ?]"
-        else:
-            output_shape_str = "[N, ?, ?, ?]"
-
-        # In opset 18+, axes must be passed as the second input to ReduceMax
         base_model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 18] >
             test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} max_val, int64{output_shape_str} max_idx)
 
     def test_fuse_reduce_argmax_explicit_axis_0(self):
         """Test fusion with explicit axis=0."""
         base_model = ir.from_onnx_text("""
-            < ir_version: 10, opset_import: ["" : 13] >
+            < ir_version: 10, opset_import: ["" : 18] >
             test_model (float[N, 14, 17] X) => (float[1, 14, 17] max_val, int64[1, 14, 17] max_idx)
+            <int64[1] axes = {0}>
             {
-                max_val = ReduceMax<axes = [0]>(X)
+                max_val = ReduceMax(X, axes)
                 max_idx = ArgMax<axis = 0>(X)
             }
         """)
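+
+        # For reference, the initializer line added above renders literally as
+        #     <int64[1] axes = {0}>
+        # so ReduceMax now reads its reduction axis from a graph initializer
+        # instead of an attribute.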
""" base_model = ir.from_onnx_text(""" - < ir_version: 10, opset_import: ["" : 13] > + < ir_version: 10, opset_import: ["" : 18] > test_model (float[N, 32, 14, 17] X) => (float[N, 32, 14, 1] max_val, int64[N, 32, 14, 1] max_idx) + { - max_val = ReduceMax(X) + max_val = ReduceMax(X, axes) max_idx = ArgMax(X) } """) @@ -223,10 +193,11 @@ def test_successful_fuse_reduce_argmax_mixed_negative_positive_axes(self): def test_fail_keepdims_mismatch(self): """Test that fusion fails when keepdims values don't match.""" base_model = ir.from_onnx_text(""" - < ir_version: 10, opset_import: ["" : 13] > + < ir_version: 10, opset_import: ["" : 18] > test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?] max_idx) + { - max_val = ReduceMax(X) + max_val = ReduceMax(X, axes) max_idx = ArgMax(X) } """) @@ -238,10 +209,11 @@ def test_fail_keepdims_mismatch(self): def test_fail_axis_mismatch(self): """Test that fusion fails when axes don't match.""" base_model = ir.from_onnx_text(""" - < ir_version: 10, opset_import: ["" : 13] > + < ir_version: 10, opset_import: ["" : 18] > test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx) + { - max_val = ReduceMax(X) + max_val = ReduceMax(X, axes) max_idx = ArgMax(X) } """) @@ -253,10 +225,11 @@ def test_fail_axis_mismatch(self): def test_fail_multiple_axes_reduce_max(self): """Test that fusion fails when ReduceMax operates on multiple axes.""" base_model = ir.from_onnx_text(""" - < ir_version: 10, opset_import: ["" : 13] > + < ir_version: 10, opset_import: ["" : 18] > test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx) + { - max_val = ReduceMax(X) + max_val = ReduceMax(X, axes) max_idx = ArgMax(X) } """) @@ -270,10 +243,11 @@ def test_fail_multiple_axes_reduce_max(self): def test_fail_select_last_index_argmax(self): """Test that fusion fails when ArgMax has select_last_index=1.""" base_model = ir.from_onnx_text(""" - < ir_version: 10, opset_import: ["" : 13] > + < ir_version: 10, opset_import: ["" : 18] > test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx) + { - max_val = ReduceMax(X) + max_val = ReduceMax(X, axes) max_idx = ArgMax(X) } """) @@ -287,10 +261,11 @@ def test_fail_select_last_index_argmax(self): def test_successful_fuse_with_default_keepdims(self): """Test fusion with default keepdims (should be 1).""" base_model = ir.from_onnx_text(""" - < ir_version: 10, opset_import: ["" : 13] > + < ir_version: 10, opset_import: ["" : 18] > test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] 
+
             {
-                max_val = ReduceMax(X)
+                max_val = ReduceMax(X, axes)
                 max_idx = ArgMax(X)
             }
         """)
@@ -302,10 +277,11 @@ def test_successful_fuse_with_default_keepdims(self):
     def test_successful_fuse_with_default_axis(self):
         """Test fusion with default axis (should be 0)."""
         base_model = ir.from_onnx_text("""
-            < ir_version: 10, opset_import: ["" : 13] >
+            < ir_version: 10, opset_import: ["" : 18] >
             test_model (float[N, 14, 17] X) => (float[1, 14, 17] max_val, int64[1, 14, 17] max_idx)
+
             {
-                max_val = ReduceMax(X)
+                max_val = ReduceMax(X, axes)
                 max_idx = ArgMax(X)
             }
         """)
@@ -317,10 +293,11 @@ def test_successful_fuse_with_default_axis(self):
     def test_successful_fuse_with_all_defaults(self):
         """Test fusion with all default values (keepdims=1, axis=0)."""
         base_model = ir.from_onnx_text("""
-            < ir_version: 10, opset_import: ["" : 13] >
+            < ir_version: 10, opset_import: ["" : 18] >
             test_model (float[N, 14, 17] X) => (float[1, 14, 17] max_val, int64[1, 14, 17] max_idx)
+
             {
-                max_val = ReduceMax(X)
+                max_val = ReduceMax(X, axes)
                 max_idx = ArgMax(X)
             }
         """)
@@ -332,10 +309,11 @@ def test_successful_fuse_with_all_defaults(self):
     def test_no_fusion_different_inputs(self):
         """Test that fusion doesn't happen when nodes have different inputs."""
         base_model = ir.from_onnx_text("""
-            < ir_version: 10, opset_import: ["" : 13] >
+            < ir_version: 10, opset_import: ["" : 18] >
             test_model (float[N, 32, 14, 17] X, float[N, 32, 14, 17] Y) => (float[N, ?, ?, ?] max_val, int64[N, ?, ?, ?] max_idx)
+
             {
-                max_val = ReduceMax(X)
+                max_val = ReduceMax(X, axes)
                 max_idx = ArgMax(Y)
             }
         """)
@@ -372,36 +350,6 @@ def test_successful_fuse_reduce_argmin_to_topk(self, _, keepdims, axis):
         else:
             output_shape_str = "[N, ?, ?, ?]"

-        base_model = ir.from_onnx_text(f"""
-            < ir_version: 10, opset_import: ["" : 13] >
-            test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} min_val, int64{output_shape_str} min_idx)
-            {{
-                min_val = ReduceMin(X)
-                min_idx = ArgMin(X)
-            }}
-        """)
-
-        # Expected: Constant for K, TopK, possibly (Constant + Squeeze) x2 for keepdims=0
-        if keepdims == 0:
-            expected_op_types = ["Constant", "TopK", "Constant", "Squeeze", "Squeeze"]
-        else:
-            expected_op_types = ["Constant", "TopK"]
-
-        self.run_test(base_model, expected_op_types)
-
-    @parameterized.expand(
-        [
-            ("keepdims_1_axis_1", 1, 1),
-            ("keepdims_0_axis_2", 0, 2),
-        ]
-    )
-    def test_successful_fuse_reduce_argmin_to_topk_opset18(self, _, keepdims, axis):
-        """Test fusion with opset 18+ (axes as input) for Min operations."""
-        if keepdims == 0:
-            output_shape_str = "[N, ?, ?]"
-        else:
-            output_shape_str = "[N, ?, ?, ?]"
-
         base_model = ir.from_onnx_text(f"""
             < ir_version: 10, opset_import: ["" : 18] >
             test_model (float[N, 32, 14, 17] X) => (float{output_shape_str} min_val, int64{output_shape_str} min_idx)
@@ -422,10 +370,11 @@ def test_successful_fuse_reduce_argmin_to_topk_opset18(self, _, keepdims, axis):
     def test_fuse_reduce_argmin_explicit_axis_0(self):
         """Test fusion with explicit axis=0."""
         base_model = ir.from_onnx_text("""
-            < ir_version: 10, opset_import: ["" : 13] >
+            < ir_version: 10, opset_import: ["" : 18] >
             test_model (float[N, 14, 17] X) => (float[1, 14, 17] min_val, int64[1, 14, 17] min_idx)
+
             {
-                min_val = ReduceMin(X)
+                min_val = ReduceMin(X, axes)
                 min_idx = ArgMin(X)
             }
         """)
@@ -478,10 +427,11 @@ def test_successful_fuse_reduce_argmin_mixed_axes(self):
         Axis -2 is equivalent to axis 2 for rank-4 tensors.
""" base_model = ir.from_onnx_text(""" - < ir_version: 10, opset_import: ["" : 13] > + < ir_version: 10, opset_import: ["" : 18] > test_model (float[N, 32, 14, 17] X) => (float[N, 32, 1, 17] min_val, int64[N, 32, 1, 17] min_idx) + { - min_val = ReduceMin(X) + min_val = ReduceMin(X, axes) min_idx = ArgMin(X) } """) @@ -491,10 +441,11 @@ def test_successful_fuse_reduce_argmin_mixed_axes(self): def test_fail_axis_mismatch(self): """Test that fusion fails when axes don't match for Min operations.""" base_model = ir.from_onnx_text(""" - < ir_version: 10, opset_import: ["" : 13] > + < ir_version: 10, opset_import: ["" : 18] > test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx) + { - min_val = ReduceMin(X) + min_val = ReduceMin(X, axes) min_idx = ArgMin(X) } """) @@ -505,10 +456,11 @@ def test_fail_axis_mismatch(self): def test_fail_keepdims_mismatch(self): """Test that fusion fails when keepdims values don't match for Min operations.""" base_model = ir.from_onnx_text(""" - < ir_version: 10, opset_import: ["" : 13] > + < ir_version: 10, opset_import: ["" : 18] > test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?] min_idx) + { - min_val = ReduceMin(X) + min_val = ReduceMin(X, axes) min_idx = ArgMin(X) } """) @@ -519,10 +471,11 @@ def test_fail_keepdims_mismatch(self): def test_fail_multiple_axes_reduce_min(self): """Test that fusion fails when ReduceMin operates on multiple axes.""" base_model = ir.from_onnx_text(""" - < ir_version: 10, opset_import: ["" : 13] > + < ir_version: 10, opset_import: ["" : 18] > test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx) + { - min_val = ReduceMin(X) + min_val = ReduceMin(X, axes) min_idx = ArgMin(X) } """) @@ -536,10 +489,11 @@ def test_fail_multiple_axes_reduce_min(self): def test_fail_select_last_index_argmin(self): """Test that fusion fails when ArgMin has select_last_index=1.""" base_model = ir.from_onnx_text(""" - < ir_version: 10, opset_import: ["" : 13] > + < ir_version: 10, opset_import: ["" : 18] > test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] min_idx) + { - min_val = ReduceMin(X) + min_val = ReduceMin(X, axes) min_idx = ArgMin(X) } """) @@ -552,10 +506,11 @@ def test_fail_select_last_index_argmin(self): def test_successful_fuse_with_default_keepdims(self): """Test fusion with default keepdims (should be 1).""" base_model = ir.from_onnx_text(""" - < ir_version: 10, opset_import: ["" : 13] > + < ir_version: 10, opset_import: ["" : 18] > test_model (float[N, 32, 14, 17] X) => (float[N, ?, ?, ?] min_val, int64[N, ?, ?, ?] 
From 264aed2f5ae76fbc8f6881a20ee23bf9dea2c414 Mon Sep 17 00:00:00 2001
From: Daniel Tu
Date: Thu, 12 Feb 2026 23:03:48 -0800
Subject: [PATCH 8/9] support both static and symbolic shapes

---
 onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
index c6f541aaac..ee52d63d79 100644
--- a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
+++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
@@ -128,7 +128,7 @@ def check(self, context, reduce_val, arg_idx, **_) -> MatchResult:

         # Step 7: Normalize axes if rank is known (handle negative indices)
         input_x = reduce_node.inputs[0]
-        rank = len(input_x.shape) if input_x.shape is not None else None
+        rank = input_x.shape.rank() if input_x.shape is not None else None

         if self._normalize_axis(reduce_axis, rank) != self._normalize_axis(arg_axis, rank):
             return check_result.fail(
@@ -159,7 +159,7 @@ def rewrite(self, op, x, reduce_val, arg_idx):

         # Step 2b: Normalize axis (convert negative to positive) if rank is known
         if axis < 0 and x.shape is not None:
-            axis = len(x.shape) + axis
+            axis = x.shape.rank() + axis

         # Step 3: Create K constant
         k_constant = op.Constant(value=ir.tensor(np.array([1], dtype=np.int64)))
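The property this patch leans on is that `ir.Shape.rank()` counts dimensions whether each dimension is a static int or a symbolic name. A minimal sketch of the negative-axis normalization the hunks touch; the helper name `normalize_axis` is illustrative only, since the rule keeps this logic in a private method:

    from typing import Optional

    from onnxscript import ir

    def normalize_axis(axis: int, shape: Optional[ir.Shape]) -> int:
        """Map a negative axis to its non-negative form when the rank is known."""
        # rank() is the number of dimensions, which is well defined even when
        # some dimensions are symbolic (e.g. a dynamic batch dimension "N").
        if axis < 0 and shape is not None:
            return shape.rank() + axis
        return axis

    symbolic = ir.Shape(("N", 32, 14, 17))  # dynamic batch, static rest
    static = ir.Shape((2, 3, 4))

    assert normalize_axis(-1, symbolic) == 3
    assert normalize_axis(-2, static) == 1
    # Unknown rank: the axis is left as-is, and check() only compares
    # normalized axes when a rank is actually available.
    assert normalize_axis(-1, None) == -1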
From 09eefce23c74f685e710ced6677d7130a787f1e3 Mon Sep 17 00:00:00 2001
From: Daniel Tu
Date: Fri, 13 Feb 2026 20:55:23 -0800
Subject: [PATCH 9/9] add a comment to indicate the version constraints

---
 onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
index ee52d63d79..aa2e03b0d4 100644
--- a/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
+++ b/onnxscript/rewriter/rules/common/_fuse_reduce_arg_to_topk.py
@@ -7,6 +7,7 @@
 - ReduceMin(X, axes=[axis], keepdims=k) + ArgMin(X, axis=axis, keepdims=k) → TopK(X, k=1, axis=axis, largest=0) [+ Squeeze if k=0]

 Constraints:
+    - This rule only works for opset 18+.
     - Both nodes must operate on the same input X.
     - Both nodes must target the same axis.
     - Both nodes must have the same keepdims attribute value.
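With the series applied, the fusion is opt-in through the rewriter rather than baked into the exporter lowering from patch 1. An end-to-end sketch is below: the module path matches the files added in this series, but the exported name of the rule set (`rules` here) is a guess at the module's public surface and may differ; `rewriter.rewrite(model, pattern_rewrite_rules=...)` is the standard onnxscript entry point.

    from onnxscript import ir, rewriter
    from onnxscript.rewriter.rules.common import _fuse_reduce_arg_to_topk

    model = ir.from_onnx_text("""
        < ir_version: 10, opset_import: ["" : 18] >
        test_model (float[2, 3, 4] X) => (float[2, 1, 4] max_val, int64[2, 1, 4] max_idx)
        <int64[1] axes = {1}>
        {
            max_val = ReduceMax <keepdims = 1> (X, axes)
            max_idx = ArgMax <axis = 1, keepdims = 1> (X)
        }
    """)

    # "rules" is an assumed module-level rule set; check the module for the real name.
    rewritten = rewriter.rewrite(model, pattern_rewrite_rules=_fuse_reduce_arg_to_topk.rules)

    # Per the tests above, a keepdims=1 match should leave exactly a Constant
    # (for K) followed by a TopK in the graph.
    print([node.op_type for node in rewritten.graph])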